import torch
import logging
import os
from datetime import datetime

# Global variables holding the tokenized splits; populated by prepare_data().
train_data = None
val_data = None


def setup_logging(log_dir="logs"):
    """Configure logging to a timestamped file and the console; return a logger."""
    # Create logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create a timestamp for the log file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),  # Also print to console
        ],
    )
    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)


def count_parameters(model):
    """Return the total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())


def get_batch(split, block_size, batch_size):
    """Sample a random batch of contiguous sequences from the chosen split.

    Targets are the inputs shifted one token to the right, so the model
    learns to predict the next token at every position. prepare_data()
    must be called first to populate the splits.
    """
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


def prepare_data(text, tokenizer):
    """Tokenize the text and split it 90/10 into train and validation tensors."""
    global train_data, val_data

    # Encode the text
    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

    # Split into train and validation sets
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]


def generate(model, tokenizer, prompt, max_tokens, device):
    """Autoregressively sample max_tokens tokens conditioned on the prompt.

    Returns only the newly generated text, with the prompt stripped off.
    Note that this leaves the model in eval mode.
    """
    model.eval()
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)[None].to(device)
    block_size = model.config.block_size
    for _ in range(max_tokens):
        # Crop the context to the model's block size and run a forward pass.
        with torch.no_grad():
            logits, _ = model(tokens[:, -block_size:])
        # Keep only the logits for the last position; dividing by a
        # temperature here would control sampling randomness.
        logits = logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokenizer.decode(tokens[0].tolist())[len(prompt) :]
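

# ---------------------------------------------------------------------------
# Usage sketch. A minimal, self-contained demo of how the helpers above fit
# together; it is not part of the training pipeline itself. CharTokenizer and
# BigramModel below are hypothetical stand-ins for whatever tokenizer and
# model this module is normally paired with. The only interface generate()
# assumes is model(tokens) -> (logits, loss) plus model.config.block_size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    logger = setup_logging()

    class CharTokenizer:
        """Hypothetical character-level tokenizer (any encode/decode object works)."""

        def __init__(self, text):
            chars = sorted(set(text))
            self.stoi = {c: i for i, c in enumerate(chars)}
            self.itos = {i: c for c, i in self.stoi.items()}
            self.vocab_size = len(chars)

        def encode(self, s):
            return [self.stoi[c] for c in s]

        def decode(self, ids):
            return "".join(self.itos[i] for i in ids)

    class BigramModel(torch.nn.Module):
        """Hypothetical stand-in exposing the interface generate() expects."""

        def __init__(self, vocab_size, block_size):
            super().__init__()
            self.config = SimpleNamespace(block_size=block_size)
            self.table = torch.nn.Embedding(vocab_size, vocab_size)

        def forward(self, idx):
            # Next-token logits at every position; no loss computed here.
            return self.table(idx), None

    text = "hello world. " * 200  # toy corpus; replace with a real dataset
    tokenizer = CharTokenizer(text)
    prepare_data(text, tokenizer)

    xb, yb = get_batch("train", block_size=8, batch_size=4)
    logger.info(f"batch shapes: x={tuple(xb.shape)}, y={tuple(yb.shape)}")

    model = BigramModel(tokenizer.vocab_size, block_size=8)
    logger.info(f"model parameters: {count_parameters(model):,}")

    sample = generate(model, tokenizer, "hello", max_tokens=20, device="cpu")
    logger.info(f"untrained sample: {sample!r}")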