Upload folder using huggingface_hub
- config.json +3 -3
- generation_config.json +4 -1
- global_step1864/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt +3 -0
- global_step1864/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt +3 -0
- global_step1864/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt +3 -0
- global_step1864/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt +3 -0
- global_step1864/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt +3 -0
- global_step1864/zero_pp_rank_0_mp_rank_00_model_states.pt +3 -0
- global_step1864/zero_pp_rank_1_mp_rank_00_model_states.pt +3 -0
- global_step1864/zero_pp_rank_2_mp_rank_00_model_states.pt +3 -0
- global_step1864/zero_pp_rank_3_mp_rank_00_model_states.pt +3 -0
- global_step1864/zero_pp_rank_4_mp_rank_00_model_states.pt +3 -0
- latest +1 -0
- model-00001-of-00002.safetensors +2 -2
- model-00002-of-00002.safetensors +2 -2
- model.safetensors.index.json +1 -1
- rng_state_0.pth +3 -0
- rng_state_1.pth +3 -0
- rng_state_2.pth +3 -0
- rng_state_3.pth +3 -0
- rng_state_4.pth +3 -0
- scheduler.pt +1 -1
- special_tokens_map.json +1 -1
- tokenizer.json +2 -2
- tokenizer_config.json +1 -9
- trainer_state.json +0 -0
- training_args.bin +2 -2
- zero_to_fp32.py +24 -12
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "/mloscratch/homes/meditron-team/models/gemma2/checkpoint-1854",
   "architectures": [
     "Gemma2ForCausalLM"
   ],
@@ -26,8 +26,8 @@
   "rms_norm_eps": 1e-06,
   "rope_theta": 10000.0,
   "sliding_window": 4096,
-  "torch_dtype": "
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.46.1",
   "use_cache": false,
-  "vocab_size":
+  "vocab_size": 256000
 }
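For reference, the updated config pins the dtype to bfloat16 and the vocabulary to the full 256000-entry Gemma vocabulary. A minimal loading sketch, assuming the standard transformers API (the local path is a placeholder, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "./checkpoint"  # placeholder: path to this uploaded folder or its Hub repo id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" in config.json
)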
generation_config.json
CHANGED
@@ -3,7 +3,10 @@
   "bos_token_id": 2,
   "cache_implementation": "hybrid",
   "do_sample": true,
-  "eos_token_id":
+  "eos_token_id": [
+    1,
+    107
+  ],
   "pad_token_id": 0,
   "transformers_version": "4.46.1"
 }
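`eos_token_id` now carries two stop ids instead of one, so generation halts on either token 1 or token 107 (in the Gemma tokenizer these are presumably `<eos>` and `<end_of_turn>`; the diff itself only records the ids). A small sketch, assuming the standard transformers API and a placeholder path:

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./checkpoint")  # placeholder path
print(gen_cfg.eos_token_id)  # expected: [1, 107]
# model.generate(**inputs, generation_config=gen_cfg) stops at the first id from this list.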
global_step1864/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fe79999046943bc62b10be111a20213d7dd658397555cb35e506be9fb64007f
+size 6274425424
global_step1864/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d7eb2bf4e3cfbe4c3fc2ecccbfadf2dffd996be7b15252b1c98b6d7c466db91
+size 6274425424
global_step1864/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46f3cdfacbdd14d8866ad32de3865c1a23bf807e57a2716aa73d52acf632690a
+size 6274425424
global_step1864/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1712d66142c471a13d90c76a44581d465eb565d20ae03f03ee32261cadb27f3
+size 6274425424
global_step1864/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93d29c222a98b329ef0f25d7fb9f88f186d4794e8af143c68aa2685bf36ca571
+size 6274425424
global_step1864/zero_pp_rank_0_mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f36d2a926915b07449ee765d24839a6a812c21dd81032121b495395438ab2940
+size 149262
global_step1864/zero_pp_rank_1_mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:972ef0e8fa5a3203d6e69152dd53be3bbba4ea528c08dae57d3d12fd99980da5
+size 149198
global_step1864/zero_pp_rank_2_mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d70591c1d2e9f248828fc53d9d32bb928fdcc7edf839bcd59f2d4865f11209df
+size 149198
global_step1864/zero_pp_rank_3_mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4445479dc07c9dc60723c434fc8d384c5adf1b3f400682ffdcf59f2120a7b2d3
+size 149198
global_step1864/zero_pp_rank_4_mp_rank_00_model_states.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7e3b28299c6f82810f8b69c924ae5fcf847bee9676b35ad91f82e5c1361151d
+size 149198
latest
ADDED
@@ -0,0 +1 @@
+global_step1864
model-00001-of-00002.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:629b6c82afc7c798a78fd7940404debc3580ce7a66cd870bcdb754011c39ab76
+size 4988025760
model-00002-of-00002.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9591cd00015a0a0947b81e3c446d35e6b65b363e078507b220def73e3b3b78f0
+size 240691728
model.safetensors.index.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size":
+    "total_size": 5228683776
   },
   "weight_map": {
     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
rng_state_0.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adb66d6753eab70f12094dcbbbde2ddd7149e12f3d3a3efcade2fd44674e8f6a
+size 15280
rng_state_1.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94e1312d93d92412eeeb9726ed060a5b8e22cdfd40b632e62c1e271444ff254b
+size 15280
rng_state_2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd190e746f60f7244032d517fc7a16b7d52753c597470122870679705c44ceff
+size 15280
rng_state_3.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66755f27fa20606f38b75c0784f4331e15225df1e95a25c242181b8e488576aa
+size 15280
rng_state_4.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c0827ee9e1de5036a02092613d8fb8ba2432aa3f718549b0e280b3eb0fa7e7d
+size 15280
scheduler.pt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c668c049bd7d0c7099cdbf10c29b8634a7fc3a06825f6bd5b09825496ac371a4
 size 1064
special_tokens_map.json
CHANGED
@@ -18,7 +18,7 @@
     "single_word": false
   },
   "pad_token": {
-    "content": "
+    "content": "<pad>",
     "lstrip": false,
     "normalized": false,
     "rstrip": false,
tokenizer.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5f7eee611703c5ce5d1eee32d9cdcfe465647b8aff0c1dfb3bed7ad7dbb05060
+size 34362873
tokenizer_config.json
CHANGED
@@ -1993,14 +1993,6 @@
       "rstrip": false,
       "single_word": false,
       "special": false
-    },
-    "256000": {
-      "content": "<|end_of_text|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
     }
   },
   "additional_special_tokens": [
@@ -2012,7 +2004,7 @@
   "clean_up_tokenization_spaces": false,
   "eos_token": "<eos>",
   "model_max_length": 1000000000000000019884624838656,
-  "pad_token": "
+  "pad_token": "<pad>",
   "sp_model_kwargs": {},
   "spaces_between_special_tokens": false,
   "tokenizer_class": "GemmaTokenizer",
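Both tokenizer files now use the stock Gemma `<pad>` token for padding, and the extra `<|end_of_text|>` entry at id 256000 is dropped from `added_tokens_decoder`, which lines up with the 256000 `vocab_size` in config.json. A quick check, assuming the standard transformers API and a placeholder path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./checkpoint")  # placeholder path
print(tok.pad_token)     # expected: <pad>
print(tok.pad_token_id)  # expected: 0, matching pad_token_id in generation_config.json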
trainer_state.json
CHANGED
The diff for this file is too large to render.
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:7cf206cf17111aa25faa473fbc4816437a914bf925de5dbca7f5a97a92b598ce
+size 8568
zero_to_fp32.py
CHANGED
@@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     return zero_stage, world_size, fp32_flat_groups
 
 
-def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
     """
     Returns fp32 state_dict reconstructed from ds checkpoint
 
@@ -211,9 +211,11 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
     if zero_stage <= 2:
-        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
     elif zero_stage == 3:
-        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
 
 
 def _zero2_merge_frozen_params(state_dict, zero_model_states):
@@ -326,7 +328,8 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
     state_dict = OrderedDict()
 
     # buffers
@@ -335,7 +338,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero2_merge_frozen_params(state_dict, zero_model_states)
+    if not exclude_frozen_parameters:
+        _zero2_merge_frozen_params(state_dict, zero_model_states)
 
     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -444,7 +448,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
     state_dict = OrderedDict()
 
     # buffers
@@ -453,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+    if not exclude_frozen_parameters:
+        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
 
     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -465,7 +471,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     return state_dict
 
 
-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -474,6 +480,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
 
     Returns:
         - pytorch ``state_dict``
@@ -511,10 +518,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
 
-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
@@ -523,9 +530,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
     """
 
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
     print(f"Saving fp32 state dict to {output_file}")
     torch.save(state_dict, output_file)
 
@@ -584,9 +592,13 @@ if __name__ == "__main__":
                         type=str,
                         default=None,
                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
     args = parser.parse_args()
 
     debug = args.debug
 
-    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+                                               args.output_file,
+                                               tag=args.tag,
+                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
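The updated script threads a new `exclude_frozen_parameters` option from the CLI down through `convert_zero_checkpoint_to_fp32_state_dict` and the ZeRO-2/ZeRO-3 merge helpers, so frozen weights can be left out of the consolidated fp32 state dict. A usage sketch based on the signatures in this diff (the paths and the positional CLI arguments follow the stock DeepSpeed script and are assumptions, not part of the commit):

# Shell: run inside the checkpoint folder whose `latest` file points at global_step1864.
#   python zero_to_fp32.py . pytorch_model_fp32.bin --exclude_frozen_parameters
#
# Python: call the module-level helper directly.
from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

state_dict = get_fp32_state_dict_from_zero_checkpoint(
    ".",                             # checkpoint folder containing global_step1864/
    tag="global_step1864",           # optional; defaults to the tag stored in `latest`
    exclude_frozen_parameters=True,  # new flag introduced by this change
)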