import logging
import os
from datetime import datetime

import torch

# Module-level holders for the encoded train/val splits (set by prepare_data).
train_data = None
val_data = None


def setup_logging(log_dir="logs"):
    """Configure root logging to a timestamped file and the console."""
    # Create the logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create a timestamp for the log file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),  # Also print to console
        ],
    )

    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)


def count_parameters(model):
    """Return the total number of parameters in the model, trainable or not."""
    return sum(p.numel() for p in model.parameters())
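

# Companion sketch, not part of the original module: when some layers are
# frozen, the trainable count can differ from count_parameters() above.
def count_trainable_parameters(model):
    """Return the number of trainable (requires_grad) parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)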


def get_batch(split, block_size, batch_size):
    """Sample a random batch of (input, target) sequences from the chosen split."""
    data = train_data if split == "train" else val_data
    if data is None:
        raise RuntimeError("call prepare_data() before sampling batches")
    # Pick batch_size random starting offsets, leaving room for a full block.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    # Targets are the inputs shifted one token to the right.
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


def prepare_data(text, tokenizer):
    """Encode text and split it 90/10 into the global train/val tensors."""
    global train_data, val_data

    # Encode the full text as a 1-D tensor of token ids
    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

    # First 90% for training, remaining 10% for validation
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]


def generate(model, tokenizer, prompt, max_tokens, device, temperature=1.0):
    """Sample up to max_tokens continuation tokens for prompt and decode them."""
    model.eval()
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)[None].to(device)
    block_size = model.config.block_size

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the running context to the model's maximum block size
            logits, _ = model(tokens[:, -block_size:])
            # Take the final position's logits and apply temperature scaling
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)

    # Decode everything, then strip the prompt so only new text is returned
    return tokenizer.decode(tokens[0].tolist())[len(prompt) :]
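

# A minimal smoke test for the data helpers, runnable directly. The
# character-level tokenizer below is an assumed stand-in; the real training
# script may use a different tokenizer entirely.
if __name__ == "__main__":

    class _CharTokenizer:
        """Toy character-level tokenizer for exercising the helpers above."""

        def __init__(self, text):
            chars = sorted(set(text))
            self.stoi = {ch: i for i, ch in enumerate(chars)}
            self.itos = {i: ch for ch, i in self.stoi.items()}

        def encode(self, s):
            return [self.stoi[ch] for ch in s]

        def decode(self, ids):
            return "".join(self.itos[i] for i in ids)

    logger = setup_logging()
    sample_text = "hello world, " * 200
    prepare_data(sample_text, _CharTokenizer(sample_text))
    xb, yb = get_batch("train", block_size=16, batch_size=4)
    # Targets should be the inputs shifted one position to the right
    assert torch.equal(xb[:, 1:], yb[:, :-1])
    logger.info(f"batch shapes: x={tuple(xb.shape)}, y={tuple(yb.shape)}")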