Spaces:
Running
Running
Upload 19 files
Browse files- checkpoints/model.pth +3 -0
- config/.DS_Store +0 -0
- config/model_config.py +23 -0
- inference.py +55 -0
- requirements.txt +2 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/data/__pycache__/tokenizer.cpython-310.pyc +0 -0
- src/data/__pycache__/tokenizer.cpython-37.pyc +0 -0
- src/data/dataset.py +39 -0
- src/data/tokenizer.py +12 -0
- src/model/__pycache__/gpt.cpython-310.pyc +0 -0
- src/model/__pycache__/gpt.cpython-37.pyc +0 -0
- src/model/gpt.py +114 -0
- src/training/__pycache__/trainer.cpython-310.pyc +0 -0
- src/training/trainer.py +61 -0
- src/utils/__pycache__/helpers.cpython-310.pyc +0 -0
- src/utils/__pycache__/helpers.cpython-37.pyc +0 -0
- src/utils/helpers.py +71 -0
- train.py +60 -0
checkpoints/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:231488f17b10270ae927e0c778b1887705bf890ed4d76422ae768c656ed4a44e
|
3 |
+
size 43421418
|
config/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
config/model_config.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hyperparameters for the character-level GPT model and its training run."""

    # --- Architecture ---
    n_embeds: int = 384    # embedding / hidden width
    n_heads: int = 6       # attention heads per transformer block
    n_layers: int = 6      # number of transformer blocks
    dropout: float = 0.3   # dropout probability used throughout the model

    # --- Training ---
    batch_size: int = 64
    block_size: int = 128        # context length in tokens
    max_iters: int = 20000
    eval_interval: int = 250     # steps between evaluations
    eval_iters: int = 200        # batches averaged per evaluation
    learning_rate: float = 3e-4
    weight_decay: float = 0.1

    # --- Paths ---
    checkpoint_path: str = "checkpoints/model.pth"
    data_path: str = "/data/nikhil_workspace/assn11/input.txt"
|
inference.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.utils.helpers import generate, setup_logging


def main():
    """Interactive inference loop: load the trained model and generate text
    from user-supplied prompts until 'quit', Ctrl-C, or EOF."""
    logger = setup_logging()
    config = ModelConfig()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Rebuild the tokenizer from the training corpus so the vocabulary matches
    # the one the checkpoint was trained with.
    with open(config.data_path) as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # The checkpoint stores the fully pickled model object, so weights_only=False
    # is required on torch>=2.6 (where weights_only began defaulting to True).
    # NOTE: unpickling executes arbitrary code — only load trusted checkpoints.
    try:
        model = torch.load(config.checkpoint_path, map_location=device, weights_only=False)
        model.eval()
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return

    # Generate text from prompts
    while True:
        try:
            prompt = input("\nEnter a prompt (or 'quit' to exit): ")
        except (KeyboardInterrupt, EOFError):
            # BUG FIX: EOFError (closed stdin) previously fell into the broad
            # `except Exception` + `continue`, producing an infinite error loop.
            logger.info("\nExiting...")
            break

        if prompt.lower() == "quit":
            break

        try:
            max_tokens = 200

            logger.info("\nGenerating...")
            result = generate(model, tokenizer, prompt, max_tokens, device)
            logger.info("\nGenerated text:")
            logger.info("=" * 50)
            logger.info(prompt + result)
            logger.info("=" * 50)
        except KeyboardInterrupt:
            logger.info("\nExiting...")
            break
        except Exception as e:
            logger.error(f"Error during generation: {e}")
            continue


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
gradio>=3.50.0
|
src/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (1.4 kB). View file
|
|
src/data/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (1.43 kB). View file
|
|
src/data/__pycache__/tokenizer.cpython-37.pyc
ADDED
Binary file (1.43 kB). View file
|
|
src/data/dataset.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Sliding-window language-modeling dataset over a 1-D tensor of token ids.

    Sample i is the pair (tokens[i : i+block_size], tokens[i+1 : i+block_size+1]):
    inputs and next-token targets shifted by one position.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Every start index that still leaves room for a full target window.
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        window = self.data[idx : idx + self.block_size]
        shifted = self.data[idx + 1 : idx + self.block_size + 1]
        return window, shifted


def create_dataloaders(text, tokenizer, config, device):
    """Encode *text* and return (train_loader, val_loader) over a 90/10 split.

    ``device`` is currently unused here; batches are moved to the device by
    the caller. Only the training loader shuffles.
    """
    tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    split_at = int(0.9 * len(tokens))
    splits = {"train": tokens[:split_at], "val": tokens[split_at:]}

    loaders = {
        name: DataLoader(
            TextDataset(part, config.block_size),
            batch_size=config.batch_size,
            shuffle=(name == "train"),
            pin_memory=True,
        )
        for name, part in splits.items()
    }
    return loaders["train"], loaders["val"]
|
src/data/tokenizer.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class CharacterTokenizer:
    """Bidirectional character <-> integer-id mapping built from a reference text.

    The vocabulary is the sorted set of unique characters in *text*; ids are
    assigned in that sorted order.
    """

    def __init__(self, text):
        vocabulary = sorted(set(text))
        self.vocab_size = len(vocabulary)
        self.stoi = dict(zip(vocabulary, range(self.vocab_size)))
        self.itos = dict(enumerate(vocabulary))

    def encode(self, s):
        """Map a string to a list of integer ids (KeyError on unseen characters)."""
        return [self.stoi[ch] for ch in s]

    def decode(self, l):
        """Map a list of integer ids back to a string."""
        return "".join(self.itos[i] for i in l)
|
src/model/__pycache__/gpt.cpython-310.pyc
ADDED
Binary file (3.71 kB). View file
|
|
src/model/__pycache__/gpt.cpython-37.pyc
ADDED
Binary file (3.82 kB). View file
|
|
src/model/gpt.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
import torch.nn as nn
from torch.nn import functional as F


class GPTModel(nn.Module):
    """Decoder-only transformer language model over a character vocabulary."""

    def __init__(self, config, vocab_size):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(vocab_size, config.n_embeds)
        self.position_embedding = nn.Embedding(config.block_size, config.n_embeds)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_f = nn.LayerNorm(config.n_embeds)
        self.lm_head = nn.Linear(config.n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        idx: (B, T) token ids with T <= config.block_size. When *targets* is
        given, logits are returned flattened to (B*T, vocab) alongside the
        cross-entropy loss; otherwise logits are (B, T, vocab) and loss is None.
        """
        _, seq_len = idx.shape

        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_embedding(idx) + self.position_embedding(positions)

        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))

        if targets is None:
            loss = None
        else:
            # Flatten (B, T, C) -> (B*T, C) for cross-entropy over the vocab.
            batch, seq, vocab = logits.shape
            logits = logits.view(batch * seq, vocab)
            loss = F.cross_entropy(logits, targets.view(batch * seq))

        return logits, loss
|
39 |
+
|
40 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sublayers, each residual."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embeds)
        self.ln2 = nn.LayerNorm(config.n_embeds)
        self.attn = MultiHeadAttention(config)
        self.mlp = FeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Pre-norm residual sublayers, dropout applied to each sublayer output.
        x = x + self.dropout(self.attn(self.ln1(x)))
        return x + self.dropout(self.mlp(self.ln2(x)))
|
55 |
+
|
56 |
+
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_size = config.n_embeds // config.n_heads
        self.n_embeds = config.n_embeds

        # One matmul produces Q, K and V together; split apart in forward().
        self.c_attn = nn.Linear(config.n_embeds, 3 * config.n_embeds)
        self.c_proj = nn.Linear(config.n_embeds, config.n_embeds)
        self.dropout = nn.Dropout(config.dropout)

        # Lower-triangular mask: position t may only attend to positions <= t.
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "mask", causal.view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        batch, seq, channels = x.shape

        # Fused projection, then split into the three attention operands.
        q, k, v = self.c_attn(x).split(self.n_embeds, dim=2)

        # (B, T, C) -> (B, n_heads, T, head_size)
        def per_head(t):
            return t.view(batch, seq, self.n_heads, self.head_size).transpose(1, 2)

        q, k, v = per_head(q), per_head(k), per_head(v)

        # Scaled dot-product scores, causally masked, then softmax.
        scale = 1.0 / torch.sqrt(torch.tensor(self.head_size))
        scores = (q @ k.transpose(-2, -1)) * scale
        scores = scores.masked_fill(self.mask[:, :, :seq, :seq] == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, heads re-merged, then output projection.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, channels)
        return self.c_proj(merged)
|
102 |
+
|
103 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embeds
        self.net = nn.Sequential(
            nn.Linear(config.n_embeds, hidden),
            nn.GELU(),
            nn.Linear(hidden, config.n_embeds),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        return self.net(x)
|
src/training/__pycache__/trainer.cpython-310.pyc
ADDED
Binary file (1.42 kB). View file
|
|
src/training/trainer.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
import logging
from src.utils.helpers import get_batch


@torch.no_grad()
def estimate_loss(model, eval_iters, block_size, batch_size, device):
    """Average the model's loss over eval_iters random batches per split.

    Puts the model in eval mode for the duration and restores train mode
    before returning {"train": float, "val": float}.
    """
    results = {}
    model.eval()
    for split in ['train', 'val']:
        bucket = torch.zeros(eval_iters)
        for k in range(eval_iters):
            inputs, labels = get_batch(split, block_size, batch_size)
            _, loss = model(inputs.to(device), labels.to(device))
            bucket[k] = loss.item()
        results[split] = bucket.mean().item()
    model.train()
    return results
|
19 |
+
|
20 |
+
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
    checkpoint_path="checkpoints/model.pth"
):
    """Run the training loop, checkpointing whenever validation loss improves.

    Evaluates every ``eval_interval`` steps (including step 0), saves the
    whole model object to ``checkpoint_path`` on improvement, and saves one
    final snapshot after the loop.
    """
    logger = logging.getLogger(__name__)
    best_val_loss = float('inf')

    for step in range(max_iters):
        # Periodic evaluation on both splits.
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            logger.info(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )

            # Checkpoint only when validation improves.
            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                logger.info(f"Saving model with val loss: {best_val_loss:.4f}")
                torch.save(model, checkpoint_path)

        # One optimization step on a fresh random batch.
        xb, yb = get_batch('train', block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)
        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # Final snapshot regardless of validation performance.
    torch.save(model, checkpoint_path)
|
src/utils/__pycache__/helpers.cpython-310.pyc
ADDED
Binary file (2.52 kB). View file
|
|
src/utils/__pycache__/helpers.cpython-37.pyc
ADDED
Binary file (1.09 kB). View file
|
|
src/utils/helpers.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
import logging
import os
from datetime import datetime

# Module-level holders for the encoded train/val splits; populated by
# prepare_data() and read by get_batch().
train_data = None
val_data = None


def setup_logging(log_dir="logs"):
    """Configure root logging to a timestamped file plus the console.

    Creates *log_dir* if needed and returns this module's logger.
    """
    os.makedirs(log_dir, exist_ok=True)

    # One log file per run, named by start time.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{stamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),  # mirror to console
        ],
    )

    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
def count_parameters(model):
    """Total number of scalar parameters in *model* (trainable or not)."""
    total = 0
    for p in model.parameters():
        total += p.numel()
    return total
|
35 |
+
|
36 |
+
|
37 |
+
def get_batch(split, block_size, batch_size):
    """Sample a random batch of (input, target) windows from the chosen split.

    Reads the module-level ``train_data`` / ``val_data`` tensors (populated by
    prepare_data); targets are the inputs shifted one token to the right.
    """
    source = train_data if split == "train" else val_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    xs = torch.stack([source[s : s + block_size] for s in starts])
    ys = torch.stack([source[s + 1 : s + block_size + 1] for s in starts])
    return xs, ys
|
43 |
+
|
44 |
+
|
45 |
+
def prepare_data(text, tokenizer):
    """Encode *text* and store a 90/10 train/val split in the module globals."""
    global train_data, val_data

    # Encode once, then split 90% train / 10% validation.
    encoded = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    split_at = int(0.9 * len(encoded))
    train_data, val_data = encoded[:split_at], encoded[split_at:]
|
56 |
+
|
57 |
+
|
58 |
+
def generate(model, tokenizer, prompt, max_tokens, device):
    """Autoregressively sample *max_tokens* characters following *prompt*.

    Returns only the newly generated text (the prompt is stripped). Sampling
    is multinomial over the full softmax distribution (temperature 1).
    """
    model.eval()
    context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)[None].to(device)
    block_size = model.config.block_size

    with torch.no_grad():
        for _ in range(max_tokens):
            # Condition on at most the last block_size tokens.
            logits, _ = model(context[:, -block_size:])
            dist = torch.softmax(logits[:, -1, :], dim=-1)  # / temperature
            nxt = torch.multinomial(dist, num_samples=1)
            context = torch.cat([context, nxt], dim=1)

    return tokenizer.decode(context[0].tolist())[len(prompt) :]
|
train.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch

from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.model.gpt import GPTModel
from src.training.trainer import train
from src.utils.helpers import generate, setup_logging, prepare_data


def main():
    """End-to-end training entry point: data prep, training, sample generation."""
    import os  # local import: keeps the file-level import block unchanged

    logger = setup_logging()
    config = ModelConfig()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Load corpus and build the character vocabulary from it.
    with open(config.data_path) as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)

    # Populate the module-level train/val splits used by get_batch().
    prepare_data(text, tokenizer)

    model = GPTModel(config, tokenizer.vocab_size)
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )

    # BUG FIX: ensure the checkpoint directory exists before the trainer tries
    # to save into it (torch.save does not create parent directories).
    ckpt_dir = os.path.dirname(config.checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    train(
        model=model,
        optimizer=optimizer,
        max_iters=config.max_iters,
        eval_interval=config.eval_interval,
        eval_iters=config.eval_iters,
        block_size=config.block_size,
        batch_size=config.batch_size,
        device=device,
        checkpoint_path=config.checkpoint_path,
    )

    # Reload the best checkpoint. The file stores a fully pickled model object,
    # so weights_only=False is required on torch>=2.6 where True became the
    # default; earlier torch versions accept the argument unchanged.
    model = torch.load(config.checkpoint_path, map_location=device, weights_only=False)
    for prompt in ["hello", "my name is", "america is"]:
        result = generate(model, tokenizer, prompt, max_tokens=200, device=device)
        logger.info(f"\nPrompt: {prompt}")
        logger.info(f"Generated: {result}")
        logger.info("=" * 40)


if __name__ == "__main__":
    main()
|