from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
import torch
import os
import psutil
import gc

# Memory management and environment setup
def cleanup_memory():
    gc.collect()
    # Only clear caches for backends that are actually available
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Set MPS memory limits and environment variables
# Note: Changed watermark ratio to a more conservative value
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7'  # Changed from 0.8
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'   # Added explicit low watermark
os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Memory monitoring
def print_memory_stats():
    process = psutil.Process()
    print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if hasattr(torch.mps, 'current_allocated_memory'):
        print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")

# Custom callback for memory monitoring
class MemoryCallback(TrainerCallback):
    def __init__(self, print_memory_stats_fn):
        self.print_memory_stats_fn = print_memory_stats_fn

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"\nStep {state.global_step}:")
            self.print_memory_stats_fn()
            cleanup_memory()

# Set device
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Load model and tokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,
    torch_dtype=torch.float32
)
model.to(device)  # Explicitly move model to device

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Add pad token
tokenizer.pad_token = tokenizer.eos_token

# Load and filter dataset
train_data = load_dataset("json", data_files={"train": "data.json"})

def filter_dataset(example):
    return len(example["prompt"]) + len(example["completion"]) <= 512

train_data = train_data.filter(filter_dataset)

# Preprocess function
def preprocess_function(examples):
    inputs = [prompt + tokenizer.eos_token + completion
              for prompt, completion in zip(examples["prompt"], examples["completion"])]
    model_inputs = tokenizer(
        inputs,
        max_length=256,
        truncation=True,
        padding="max_length"
    )
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs

# Preprocess the dataset
train_dataset = train_data["train"].map(preprocess_function, batched=True)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # Reduced from 32
    logging_dir="./logs",
    fp16=False,
    eval_strategy="no",
    learning_rate=1e-5,              # Reduced from 5e-5
    save_steps=100,
    save_total_limit=2,
    gradient_checkpointing=True,
    optim="adamw_torch",
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    torch_compile=False,
    max_grad_norm=1.0,               # Increased from 0.5
    logging_steps=5,                 # More frequent logging
    max_steps=1000,
    warmup_steps=300,                # Increased warmup steps
    weight_decay=0.2,                # Increased from 0.01
    logging_first_step=True,
    lr_scheduler_type="cosine_with_restarts",  # Changed to cosine with restarts
    warmup_ratio=0.15,               # Increased warmup ratio
)

# Clear cache before training
cleanup_memory()

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[MemoryCallback(print_memory_stats)]
)

# Monitor initial memory usage
print("Initial memory usage:")
print_memory_stats()

# Training with error handling
try:
    trainer.train()
except Exception as e:
    print(f"Training error: {str(e)}")
    cleanup_memory()
    try:
        model.save_pretrained("./lockin_model_partial")
        tokenizer.save_pretrained("./lockin_model_partial")
        print("Saved partial progress")
    except Exception:
        print("Could not save partial progress")
    raise e
finally:
    cleanup_memory()

# Save the complete model
try:
    model.save_pretrained("./lockin_model")
    tokenizer.save_pretrained("./lockin_model")
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {str(e)}")

# Final cleanup
cleanup_memory()

print("\nFinal memory usage:")
print_memory_stats()
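
# ---------------------------------------------------------------------------
# Optional sanity check (not part of the training run above): reload the
# saved model and generate one completion. This is a minimal sketch; it
# assumes training finished and the model saved cleanly to ./lockin_model,
# and that inference prompts follow the same "prompt + eos" format used in
# preprocess_function. The example prompt string is a placeholder.
# ---------------------------------------------------------------------------
def sanity_check_generation(prompt_text="Example prompt"):
    eval_model = AutoModelForCausalLM.from_pretrained("./lockin_model").to(device)
    eval_tokenizer = AutoTokenizer.from_pretrained("./lockin_model")
    eval_model.eval()

    # Mirror the training format: prompt followed by the EOS separator
    inputs = eval_tokenizer(prompt_text + eval_tokenizer.eos_token, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = eval_model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=eval_tokenizer.eos_token_id,
        )
    print(eval_tokenizer.decode(output_ids[0], skip_special_tokens=True))

# Uncomment to run a quick check after training:
# sanity_check_generation("Your test prompt here")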