import os
import logging
from datetime import datetime

import torch

# Module-level holders for the encoded dataset, populated by prepare_data()
train_data = None
val_data = None


def setup_logging(log_dir="logs"):
    """Configure logging to a timestamped file and to the console."""
    # Create the logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Timestamp the log file so repeated runs don't overwrite each other
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    # Send records to both the log file and stdout
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),  # Also print to console
        ],
    )
    logging.info(f"Logging to {log_file}")
    return logging.getLogger(__name__)
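
# Usage sketch (hypothetical): call once at program start and reuse the logger.
#
#     logger = setup_logging()          # writes logs/training_<timestamp>.log
#     logger.info("starting run")       # appears in the file and on stdout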


def count_parameters(model):
    """Return the total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())


def get_batch(split, block_size, batch_size):
    """Sample a random batch of (input, target) sequences from the chosen split."""
    data = train_data if split == "train" else val_data
    # Random starting offsets; targets are the inputs shifted right by one token
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y
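
# Usage sketch (hypothetical values): prepare_data() must have populated the
# globals first; the returned tensors can then be moved to the training device.
#
#     xb, yb = get_batch("train", block_size=256, batch_size=64)
#     xb, yb = xb.to(device), yb.to(device)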


def prepare_data(text, tokenizer):
    """Encode the text and split it 90/10 into train and validation sets."""
    global train_data, val_data
    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    # First 90% for training, last 10% held out for validation
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
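
# The tokenizer is assumed to expose encode(str) -> list[int] and
# decode(list[int]) -> str. A minimal character-level stand-in (an
# illustration only, not part of the training code) could look like:
#
#     class CharTokenizer:
#         def __init__(self, text):
#             self.chars = sorted(set(text))
#             self.stoi = {c: i for i, c in enumerate(self.chars)}
#         def encode(self, s):
#             return [self.stoi[c] for c in s]
#         def decode(self, ids):
#             return "".join(self.chars[i] for i in ids)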


def generate(model, tokenizer, prompt, max_tokens, device):
    """Autoregressively sample max_tokens new tokens following the prompt."""
    model.eval()
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long)[None].to(device)
    block_size = model.config.block_size
    for _ in range(max_tokens):
        with torch.no_grad():
            # Crop the context to the model's block size before each forward pass
            logits, _ = model(tokens[:, -block_size:])
        logits = logits[:, -1, :]  # last position only (temperature scaling omitted)
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
    # Strip the prompt from the decoded output; assumes decode(encode(p)) == p
    return tokenizer.decode(tokens[0].tolist())[len(prompt) :]
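
# End-to-end sketch (hypothetical model/tokenizer objects; `model` is assumed
# to return (logits, loss) and expose config.block_size, as used above):
#
#     logger = setup_logging()
#     logger.info(f"{count_parameters(model):,} parameters")
#     prepare_data(open("input.txt").read(), tokenizer)
#     x, y = get_batch("train", model.config.block_size, batch_size=32)
#     sample = generate(model, tokenizer, "Once upon a time", 100, "cpu")
#     logger.info(sample)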