{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 1 of reorganization complete.\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"\n",
"def reorganize_checkpoints_step1(root_dir, layers):\n",
" for layer in layers:\n",
" layer_path = os.path.join(root_dir, layer)\n",
" if not os.path.exists(layer_path):\n",
" print(f\"Warning: {layer_path} does not exist. Skipping.\")\n",
" continue\n",
"\n",
" # Create the new results directory\n",
" results_dir = f\"{layer}_checkpoints\"\n",
" results_path = os.path.join(root_dir, results_dir)\n",
" os.makedirs(results_path, exist_ok=True)\n",
"\n",
"\n",
" # Iterate through trainer directories\n",
" for trainer in os.listdir(layer_path):\n",
" if trainer.startswith('trainer_'):\n",
" trainer_path = os.path.join(layer_path, trainer)\n",
" config_path = os.path.join(trainer_path, 'config.json')\n",
" checkpoints_path = os.path.join(trainer_path, 'checkpoints')\n",
"\n",
" # Create trainer directory in results\n",
" trainer_results_path = os.path.join(results_path, trainer)\n",
" os.makedirs(trainer_results_path, exist_ok=True)\n",
"\n",
" # Copy config.json if it exists\n",
" if os.path.exists(config_path):\n",
" shutil.copy2(config_path, trainer_results_path)\n",
" else:\n",
" print(f\"Warning: config.json not found in {trainer_path}\")\n",
"\n",
" # Move checkpoints directory if it exists\n",
" if os.path.exists(checkpoints_path):\n",
" shutil.move(checkpoints_path, trainer_results_path)\n",
" else:\n",
" print(f\"Warning: checkpoints directory not found in {trainer_path}\")\n",
"\n",
" print(\"Step 1 of reorganization complete.\")\n",
"\n",
"\n",
"\n",
"root_directory = \"/workspace/sae_eval/dictionary_learning/dictionaries/gemma-2-2b/gemma-2-2b_sweep_topk_ctx128_ef8_0824\"\n",
"layers_to_process = [\"resid_post_layer_3\", \"resid_post_layer_7\", \"resid_post_layer_11\", \"resid_post_layer_15\", \"resid_post_layer_19\"]\n",
"reorganize_checkpoints_step1(root_directory, layers_to_process)"
]
},
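{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Optional sanity check after step 1. The sketch below (assuming the `root_directory` and `layers_to_process` defined above) prints the reorganized tree so you can confirm that each `trainer_*` directory inside the new `*_checkpoints` folder holds a copied `config.json` and the moved `checkpoints/` subdirectory."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import os\n",
  "\n",
  "def print_tree(path, max_depth=3, _depth=0):\n",
  "    # Recursively list directory contents, indenting by depth\n",
  "    if _depth > max_depth or not os.path.isdir(path):\n",
  "        return\n",
  "    for entry in sorted(os.listdir(path)):\n",
  "        print('    ' * _depth + entry)\n",
  "        print_tree(os.path.join(path, entry), max_depth, _depth + 1)\n",
  "\n",
  "# Inspect the first reorganized layer; change the index to check the others\n",
  "print_tree(os.path.join(root_directory, f\"{layers_to_process[0]}_checkpoints\"))"
 ]
},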
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 2 of reorganization complete.\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import re\n",
"import json\n",
"\n",
"def reorganize_checkpoints_step2(root_dir, checkpoint_dirs):\n",
" for checkpoint_dir in checkpoint_dirs:\n",
" checkpoint_path = os.path.join(root_dir, checkpoint_dir)\n",
" if not os.path.exists(checkpoint_path):\n",
" print(f\"Warning: {checkpoint_path} does not exist. Skipping.\")\n",
" continue\n",
"\n",
" # Iterate through trainer directories\n",
" for trainer in os.listdir(checkpoint_path):\n",
" if trainer.startswith('trainer_'):\n",
" trainer_path = os.path.join(checkpoint_path, trainer)\n",
" config_path = os.path.join(trainer_path, 'config.json')\n",
" checkpoints_path = os.path.join(trainer_path, 'checkpoints')\n",
"\n",
" if not os.path.exists(checkpoints_path):\n",
" print(f\"Warning: checkpoints directory not found in {trainer_path}\")\n",
" continue\n",
"\n",
" # Process each checkpoint\n",
" for checkpoint in os.listdir(checkpoints_path):\n",
" match = re.match(r'ae_(\\d+)\\.pt', checkpoint)\n",
" if match:\n",
" step = match.group(1)\n",
" new_checkpoint_dir = os.path.join(checkpoint_path, f'{trainer}_step_{step}')\n",
" os.makedirs(new_checkpoint_dir, exist_ok=True)\n",
"\n",
" # Copy config.json\n",
" if os.path.exists(config_path):\n",
" with open(config_path, 'r') as f:\n",
" config = json.load(f)\n",
" config['trainer']['steps'] = step\n",
" new_config_path = os.path.join(new_checkpoint_dir, 'config.json')\n",
" with open(new_config_path, 'w') as f:\n",
" json.dump(config, f, indent=2)\n",
" else:\n",
" raise Exception(f\"Config.json not found for {trainer}\")\n",
" print(f\"Warning: config.json not found for {trainer}\")\n",
"\n",
" # Move and rename checkpoint file\n",
" old_checkpoint_path = os.path.join(checkpoints_path, checkpoint)\n",
" new_checkpoint_path = os.path.join(new_checkpoint_dir, 'ae.pt')\n",
" shutil.move(old_checkpoint_path, new_checkpoint_path)\n",
"\n",
" # Remove the original checkpoints directory if it's empty\n",
" if not os.listdir(checkpoints_path):\n",
" os.rmdir(checkpoints_path)\n",
" else:\n",
" raise Exception(f\"Checkpoints directory {checkpoints_path} is not empty.\")\n",
"\n",
" # Remove the config.json file\n",
" if os.path.exists(config_path):\n",
" os.remove(config_path)\n",
" else:\n",
" print(f\"Warning: config.json not found for {trainer}\")\n",
"\n",
" # Remove the trainer directory\n",
" if not os.listdir(trainer_path):\n",
" os.rmdir(trainer_path)\n",
" else:\n",
" raise Exception(f\"Trainer directory {trainer_path} is not empty.\")\n",
"\n",
" print(\"Step 2 of reorganization complete.\")\n",
"\n",
"checkpoint_dirs_to_process = []\n",
"for layer in layers_to_process:\n",
" checkpoint_dirs_to_process.append(f\"{layer}_checkpoints\")\n",
"\n",
"reorganize_checkpoints_step2(root_directory, checkpoint_dirs_to_process)"
]
},
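{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Optional verification after step 2. A minimal sketch (assuming `root_directory` and `checkpoint_dirs_to_process` from the cells above): after the move, every `trainer_*_step_*` directory should contain exactly `ae.pt` and `config.json`, and any deviation is flagged."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import os\n",
  "\n",
  "def verify_step2(root_dir, checkpoint_dirs):\n",
  "    # Flag any reorganized directory that does not hold exactly ae.pt and config.json\n",
  "    for checkpoint_dir in checkpoint_dirs:\n",
  "        checkpoint_path = os.path.join(root_dir, checkpoint_dir)\n",
  "        if not os.path.exists(checkpoint_path):\n",
  "            print(f\"Warning: {checkpoint_path} does not exist. Skipping.\")\n",
  "            continue\n",
  "        for entry in sorted(os.listdir(checkpoint_path)):\n",
  "            entry_path = os.path.join(checkpoint_path, entry)\n",
  "            if not os.path.isdir(entry_path):\n",
  "                continue\n",
  "            contents = set(os.listdir(entry_path))\n",
  "            if contents != {'ae.pt', 'config.json'}:\n",
  "                print(f\"Unexpected contents in {entry_path}: {contents}\")\n",
  "\n",
  "verify_step2(root_directory, checkpoint_dirs_to_process)\n",
  "print(\"Verification check complete.\")"
 ]
},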
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def compare_pytorch_models(file1, file2):\n",
" # Load the models\n",
" model1 = torch.load(file1, map_location=torch.device('cpu'))\n",
" model2 = torch.load(file2, map_location=torch.device('cpu'))\n",
" \n",
" # If the loaded objects are not dictionaries, assume they are the state dictionaries\n",
" if not isinstance(model1, dict):\n",
" model1 = model1.state_dict()\n",
" if not isinstance(model2, dict):\n",
" model2 = model2.state_dict()\n",
" \n",
" # Check if the models have the same keys\n",
" assert set(model1.keys()) == set(model2.keys()), \"Models have different keys\"\n",
" \n",
" # Compare each parameter\n",
" for key in model1.keys():\n",
" print(key)\n",
" assert torch.allclose(model1[key], model2[key], atol=1e-7), f\"Mismatch in parameter {key}\"\n",
" \n",
" print(\"Models are identical within the specified tolerance.\")\n",
"\n",
"# Usage example (you can run this in your Jupyter notebook):\n",
"compare_pytorch_models('ae_4882.pt', 'ae_4882_converted.pt')"
]
}
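,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A last hedged check on step 2's config rewriting: load one rewritten `config.json` and confirm that its recorded step count matches the directory suffix. The `trainer_0_step_4882` path below is hypothetical; substitute a directory that actually exists in your run."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import json\n",
  "import os\n",
  "\n",
  "# Hypothetical example path; point this at a real trainer_*_step_* directory\n",
  "example_dir = os.path.join(root_directory, f\"{layers_to_process[0]}_checkpoints\", \"trainer_0_step_4882\")\n",
  "\n",
  "with open(os.path.join(example_dir, 'config.json')) as f:\n",
  "    config = json.load(f)\n",
  "\n",
  "# The steps field was rewritten in step 2 to the checkpoint's step count\n",
  "print(config['trainer']['steps'])"
 ]
}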
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}