import torch
import logging
import os
from datetime import datetime

# Global variables for data (filled in by prepare_data)
train_data = None
val_data = None


def setup_logging(log_dir="logs"):
    """Configure file + console logging and return a logger for this module."""
    # Create logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create a timestamp for the log file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),  # Also print to console
        ],
    )
    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)
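
# Usage sketch (illustrative; a real training script would presumably call
# this once at startup):
#
#     logger = setup_logging()
#     logger.info("starting training run")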


def count_parameters(model):
    """Return the total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())
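
# For example (the reporting line below is an assumption, not from this file):
#
#     logging.info(f"{count_parameters(model) / 1e6:.2f}M parameters")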


def get_batch(split, block_size, batch_size):
    """Sample a random batch of (input, target) sequences from the chosen split."""
    data = train_data if split == "train" else val_data
    # Pick batch_size random starting offsets that leave room for a full block
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    # Targets are the same sequences shifted one token to the right
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y
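
# Shape sketch, assuming prepare_data (defined below) has already populated
# the globals; the sizes here are illustrative:
#
#     xb, yb = get_batch("train", block_size=64, batch_size=8)
#     # xb.shape == yb.shape == torch.Size([8, 64])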


def prepare_data(text, tokenizer):
    """Prepare train and validation data."""
    global train_data, val_data
    # Encode the text
    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    # Split into train and validation sets
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
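
# Usage sketch with a hypothetical tokenizer exposing encode()/decode()
# (this module does not define one):
#
#     prepare_data(open("input.txt").read(), tokenizer)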


def generate(model, tokenizer, prompt, max_tokens, device):
    """Autoregressively sample max_tokens new tokens after the prompt."""
    model.eval()
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)[None].to(device)
    block_size = model.config.block_size
    for _ in range(max_tokens):
        # Crop the context to the model's block size before the forward pass
        with torch.no_grad():
            logits, _ = model(tokens[:, -block_size:])
        # Sample the next token from the distribution at the last position
        logits = logits[:, -1, :]  # / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
    # Decode everything, then strip the prompt so only new text is returned
    return tokenizer.decode(tokens[0].tolist())[len(prompt) :]
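

if __name__ == "__main__":
    # Smoke-test sketch for the helpers above. The character-level tokenizer
    # here is a stand-in defined only for this demo; the surrounding project
    # presumably supplies its own tokenizer and model.
    class _CharTokenizer:
        def __init__(self, text):
            chars = sorted(set(text))
            self.stoi = {c: i for i, c in enumerate(chars)}
            self.itos = {i: c for c, i in self.stoi.items()}

        def encode(self, s):
            return [self.stoi[c] for c in s]

        def decode(self, ids):
            return "".join(self.itos[i] for i in ids)

    logger = setup_logging()
    text = "hello world, " * 100  # toy corpus, illustrative only
    tok = _CharTokenizer(text)
    prepare_data(text, tok)
    xb, yb = get_batch("train", block_size=16, batch_size=4)
    logger.info(f"batch shapes: x={tuple(xb.shape)}, y={tuple(yb.shape)}")
    # generate() additionally needs a trained model whose forward pass returns
    # (logits, loss) and which exposes model.config.block_size; none is built here.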