File size: 1,734 Bytes
0955f14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def load_checkpoint(filepath: str) -> dict:\n",
" \"\"\"\n",
" Load a checkpoint file.\n",
"\n",
" Args:\n",
" filepath (str): Path to the .ckpt file.\n",
"\n",
" Returns:\n",
" dict: Contents of the checkpoint file.\n",
" \"\"\"\n",
" checkpoint = torch.load(filepath, map_location=torch.device('cpu'))\n",
" return checkpoint\n",
"\n",
"checkpoint_path = 'ckpt.pt'\n",
"checkpoint_data = load_checkpoint(checkpoint_path)\n",
"\n",
"# Print the keys to understand what's inside\n",
"print(checkpoint_data.keys())\n",
"\n",
"# If you want to view specific information, access it using the keys\n",
"# For example, to view the model's state_dict\n",
"model_state = checkpoint_data.get('state_dict', None)\n",
"if model_state:\n",
" print(\"Model's state dict:\", model_state)\n",
"\n",
"# To view training information like current learning rate, iterations, etc.\n",
"training_info = checkpoint_data.get('training_info', None)\n",
"if training_info:\n",
" print(\"Training Info:\", training_info)\n",
"\n",
"# To view config, if it's stored in the checkpoint\n",
"config = checkpoint_data.get('config', None)\n",
"if config:\n",
" print(\"Configurations:\", config)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|