# Checkpoint utilities: locate the latest saved checkpoint and convert
# DeepSpeed checkpoint directories into plain PyTorch Lightning checkpoints.
import os | |
import re | |
from os.path import join | |
import torch | |
def get_latest(name, checkpoint_dir, extra_args=None):
    """Return info for the highest-step checkpoint under ``checkpoint_dir/name``.

    Filenames are expected to contain ``step=<int>`` before their extension
    (e.g. ``epoch=4-step=39150.ckpt``).

    name: subdirectory (relative checkpoint name) to search.
    checkpoint_dir: root directory containing ``name``.
    extra_args: optional dict passed through unchanged (a fresh empty dict
        is used when None to avoid a shared mutable default).

    Returns: dict with ``chkpt_name`` (path relative to ``checkpoint_dir``)
        and ``extra_args``.
    Raises: FileNotFoundError if the directory contains no files,
        ValueError if a filename has no parseable step number.
    """
    if extra_args is None:
        extra_args = dict()
    files = os.listdir(join(checkpoint_dir, name))
    if not files:
        raise FileNotFoundError(f'No checkpoint files found in {join(checkpoint_dir, name)}')

    def step_of(f):
        # "epoch=4-step=39150.ckpt" -> 39150
        return int(f.split("step=")[-1].split(".")[0])

    # Plain max() over the parsed step numbers; no need to build a torch tensor
    # just to take an argmax of a Python list.
    selected = max(files, key=step_of)
    return dict(
        chkpt_name=os.path.join(name, selected),
        extra_args=extra_args)
# Extracts the bare parameter name from a DeepSpeed state-dict key,
# e.g. "_forward_module.encoder.weight" -> "encoder.weight".
DS_PARAM_REGEX = r'_forward_module\.(.+)'
def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
    '''
    Build a single-file PyTorch Lightning checkpoint from a DeepSpeed
    checkpoint directory, then patch in any parameters the DeepSpeed
    conversion utility failed to carry over.

    deepspeed_ckpt_path: DeepSpeed checkpoint directory (ends in ".ckpt").
    pl_ckpt_path: destination for the converted checkpoint; defaults to the
        same location/name as the DeepSpeed directory with a ".pt" extension.

    Returns: path to the converted checkpoint.
    Raises: ValueError when the input path is not a DeepSpeed ".ckpt" directory.
    '''
    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    looks_like_ds_dir = (
        deepspeed_ckpt_path.endswith('.ckpt')
        and os.path.isdir(deepspeed_ckpt_path)
    )
    if not looks_like_ds_dir:
        raise ValueError(
            'args.ckpt_dir should point to the checkpoint directory'
            ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
        )

    # Derive the default output path by swapping the extension: .ckpt --> .pt
    if not pl_ckpt_path:
        pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt'

    # Run the (expensive) fp32 reconstruction only if it hasn't been done yet.
    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)

    # Merge back any parameters the conversion utility dropped, then persist.
    merged_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(merged_ckpt, pl_ckpt_path)
    return pl_ckpt_path
def get_optim_files(checkpoint_dir):
    """Return sorted full paths of every file in checkpoint_dir whose name contains "optim"."""
    optim_names = sorted(name for name in os.listdir(checkpoint_dir) if "optim" in name)
    return [join(checkpoint_dir, name) for name in optim_names]
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of the model-states file inside a DeepSpeed checkpoint dir.

    checkpoint_dir: directory containing the DeepSpeed shard files.
    zero_stage: unused; kept for interface compatibility with callers.

    Returns: full path to the (lexicographically first) file whose name
        contains "model_states". Sorting makes the choice deterministic —
        os.listdir order is arbitrary.
    Raises: FileNotFoundError when no such file exists (previously a bare
        IndexError from indexing an empty list).
    """
    candidates = sorted(f for f in os.listdir(checkpoint_dir) if "model_states" in f)
    if not candidates:
        raise FileNotFoundError(f'No model_states file found in {checkpoint_dir}')
    return join(checkpoint_dir, candidates[0])
def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
    '''
    Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32
    checkpoint into the fp32 state dict.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    fp32_ckpt_path: Path to the reconstructed fp32 PyTorch Lightning checkpoint.

    Returns: the fp32 checkpoint dict with missing parameters patched in.
    '''
    from pytorch_lightning.utilities.deepspeed import ds_checkpoint_dir

    # This first part is based on
    # pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt
    ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
    ds_sd = ds_ckpt['module']
    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']

    for k, v in ds_sd.items():
        # Explicit None check instead of a bare `except:`; the only expected
        # failure here is a key that doesn't match DS_PARAM_REGEX.
        match = re.match(DS_PARAM_REGEX, k)
        if match is None:
            print(f'Failed to extract parameter from DeepSpeed key {k}')
            continue
        param_name = match.group(1)

        v = v.to(torch.float32)
        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
            fp32_sd[param_name] = v
        else:
            # Parameter exists in both: sanity-check they (approximately) agree.
            assert torch.allclose(v, fp32_sd[param_name].to(torch.float32), atol=1e-2)
    return fp32_ckpt
def get_version_and_step(f, i): | |
step = f.split("step=")[-1].split(".")[0] | |
if "-v" in step: | |
[step, version] = step.split("-v") | |
else: | |
step, version = step, 0 | |
return int(version), int(step), i | |
def get_latest_ds(name, extra_args=None):
    """Find the newest DeepSpeed checkpoint under ../checkpoints/<name>,
    convert it to a plain fp32 checkpoint if needed, and return its info.

    Returns: dict with ``chkpt_name`` (path relative to ../checkpoints)
        and ``extra_args``.
    """
    if extra_args is None:
        extra_args = dict()
    files = os.listdir(f"../checkpoints/{name}")
    # Rank every file by (version, step); the winner's original index is
    # carried in the last tuple slot.
    ranked = sorted(
        (get_version_and_step(f, i) for i, f in enumerate(files)),
        reverse=True,
    )
    selected = files[ranked[0][-1]]
    # print(f"Selecting file: {selected}")

    ds_chkpt = join(name, selected)
    reg_chkpt = join(name + "_fp32", selected)
    reg_chkpt_path = join("../checkpoints", reg_chkpt)
    if not os.path.exists(reg_chkpt_path):
        # Lazily materialize the fp32 checkpoint alongside the DeepSpeed one.
        os.makedirs(os.path.dirname(reg_chkpt_path), exist_ok=True)
        print(f"Checkpoint {reg_chkpt} does not exist, converting from deepspeed")
        convert_deepspeed_checkpoint(join("../checkpoints", ds_chkpt), reg_chkpt_path)
    return dict(
        chkpt_name=reg_chkpt,
        extra_args=extra_args)
def get_all_models_in_dir(name, checkpoint_dir, extra_args=None):
    """Return {"<name>/<model>/train": latest-checkpoint info} for every
    model directory under checkpoint_dir/name."""
    results = {}
    for entry in os.listdir(join(checkpoint_dir, name)):
        key = f"{name}/{entry}/train"
        # print(f'"{key}",')
        results[key] = get_latest(key, checkpoint_dir, extra_args)
    return results
def saved_model_dict(checkpoint_dir):
    """Build the registry of all known saved models.

    Combines the latest checkpoints discovered in several training run
    directories with hand-written entries for baseline models (davenet,
    cavmae, imagebind) that have no checkpoint of their own
    (chkpt_name=None). Also registers friendly aliases ("denseav_*")
    pointing at specific training runs.

    checkpoint_dir: root directory containing the dated run folders.
    Returns: dict mapping model name -> dict with ``chkpt_name`` and
        ``extra_args`` (and, for baselines, ``data_args``).
    """
    model_info = {
        **get_all_models_in_dir(
            "9-5-23-mixed",
            checkpoint_dir,
            extra_args=dict(
                mixup_weight=0.0,
                sim_use_cls=False,
                audio_pool_width=1,
                memory_buffer_size=0,
                loss_leak=0.0)
        ),
        **get_all_models_in_dir(
            "1-23-24-rebuttal-heads",
            checkpoint_dir,
            extra_args=dict(
                loss_leak=0.0)
        ),
        **get_all_models_in_dir(
            "11-8-23",
            checkpoint_dir,
            extra_args=dict(loss_leak=0.0)),
        **get_all_models_in_dir(
            "10-30-23-3",
            checkpoint_dir,
            extra_args=dict(loss_leak=0.0)),
        # Baseline: DAVEnet — no local checkpoint; model/data args only.
        "davenet": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="davenet",
                image_aligner_type=None,
                audio_model_type="davenet",
                audio_aligner_type=None,
                audio_input="davenet_spec",
                use_cached_embs=False,
                dropout=False,
                sim_agg_heads=1,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                override_target_length=20,
                audio_model_type="davenet",
            ),
        ),
        # Baseline: CAV-MAE — no local checkpoint; model/data args only.
        "cavmae": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="cavmae",
                image_aligner_type=None,
                audio_model_type="cavmae",
                audio_aligner_type=None,
                audio_input="spec",
                use_cached_embs=False,
                sim_agg_heads=1,
                dropout=False,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
                learn_audio_cls=False,
                sim_agg_type="cavmae",
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                audio_model_type="cavmae",
                override_target_length=10,
            ),
        ),
        # Baseline: ImageBind — no local checkpoint; model/data args only.
        "imagebind": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="imagebind",
                image_aligner_type=None,
                audio_model_type="imagebind",
                audio_aligner_type=None,
                audio_input="spec",
                use_cached_embs=False,
                sim_agg_heads=1,
                dropout=False,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
                learn_audio_cls=False,
                sim_agg_type="imagebind",
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                audio_model_type="imagebind",
                override_target_length=10,
            ),
        ),
    }
    # Friendly aliases for specific published runs.
    model_info["denseav_language"] = model_info["10-30-23-3/places_base/train"]
    model_info["denseav_sound"] = model_info["11-8-23/hubert_1h_asf_cls_full_image_train_small_lr/train"]
    model_info["denseav_2head"] = model_info["1-23-24-rebuttal-heads/mixed-2h/train"]
    return model_info