"""Train a character-level GPT on a text corpus, then log sample generations."""
# Third-party
import torch

# Project-local
from config.model_config import ModelConfig
from src.data.tokenizer import CharacterTokenizer
from src.model.gpt import GPTModel
from src.training.trainer import train
from src.utils.helpers import generate, prepare_data, setup_logging
def main():
    """Train a character-level GPT on the configured corpus, then log sample generations.

    Side effects: reads ``config.data_path``, writes/reads ``config.checkpoint_path``,
    and emits progress/samples through the configured logger.
    """
    logger = setup_logging()
    config = ModelConfig()

    # Prefer GPU when available; model and training tensors live on this device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Using device: %s", device)

    # Load the raw training text with an explicit encoding (the platform default
    # is not guaranteed to be UTF-8), then build a character-level vocabulary.
    with open(config.data_path, encoding="utf-8") as f:
        text = f.read()
    tokenizer = CharacterTokenizer(text)
    prepare_data(text, tokenizer)

    model = GPTModel(config, tokenizer.vocab_size).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    train(
        model=model,
        optimizer=optimizer,
        max_iters=config.max_iters,
        eval_interval=config.eval_interval,
        eval_iters=config.eval_iters,
        block_size=config.block_size,
        batch_size=config.batch_size,
        device=device,
        checkpoint_path=config.checkpoint_path,
    )

    # Reload the checkpoint written by train(). NOTE(review): this assumes the
    # whole model object was pickled (torch.save(model)); with torch >= 2.6,
    # torch.load defaults to weights_only=True and would reject it, hence the
    # explicit weights_only=False — acceptable only because the checkpoint is
    # produced by this same program, never untrusted input.
    model = torch.load(
        config.checkpoint_path, map_location=device, weights_only=False
    )
    # Inference mode: disable dropout/batchnorm-training behavior for sampling.
    # NOTE(review): harmless if generate() already does this internally.
    model.eval()

    for prompt in ["hello", "my name is", "america is"]:
        result = generate(model, tokenizer, prompt, max_tokens=200, device=device)
        logger.info("\nPrompt: %s", prompt)
        logger.info("Generated: %s", result)
        logger.info("=" * 40)
if __name__ == "__main__": | |
main() | |