# DenseAV-Lowell / DenseAV / denseav / saved_models.py
# Uploaded by lorocksUMD ("Upload 32 files", commit e6d4b46, verified).
import os
import re
from os.path import join
import torch
def get_latest(name, checkpoint_dir, extra_args=None):
    """Select the checkpoint with the highest training step in ``checkpoint_dir/name``.

    Every entry of the directory is expected to be named like ``...step=<int>.<ext>``
    (e.g. ``epoch=4-step=39150.ckpt``); the step is the text after the last ``"step="``
    up to the next ``"."``.

    Args:
        name: Sub-directory of ``checkpoint_dir`` holding the checkpoints.
        checkpoint_dir: Root checkpoint directory.
        extra_args: Optional dict stored unchanged in the result; defaults to ``{}``.

    Returns:
        ``dict(chkpt_name=<name>/<selected file>, extra_args=extra_args)`` where
        ``chkpt_name`` is relative to ``checkpoint_dir``.

    Raises:
        FileNotFoundError: if the directory is missing or has no entries.
        ValueError: if an entry name does not carry an integer step.
    """
    if extra_args is None:
        extra_args = dict()
    run_dir = join(checkpoint_dir, name)
    files = os.listdir(run_dir)
    if not files:
        raise FileNotFoundError(f"No checkpoints found in {run_dir}")
    # max() returns the first maximal entry, the same tie-breaking as torch.argmax,
    # without building a tensor just to find an index.
    selected = max(files, key=lambda f: int(f.split("step=")[-1].split(".")[0]))
    return dict(
        chkpt_name=os.path.join(name, selected),
        extra_args=extra_args)
# Used with re.match in _merge_deepspeed_weights: matches DeepSpeed state-dict keys that
# begin with "_forward_module." and captures the remaining parameter name (group 1).
DS_PARAM_REGEX = r'_forward_module\.(.+)'
def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
    '''
    Create a PyTorch Lightning checkpoint from a DeepSpeed checkpoint directory, patching in
    parameters that the DeepSpeed conversion utility fails to carry over.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder; its name must end in ".ckpt"
        (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt"). A trailing path separator is accepted.
    pl_ckpt_path: Path of the reconstructed PyTorch Lightning checkpoint. If not given, it is
        placed next to the DeepSpeed checkpoint directory with the same name but a ".pt"
        extension.
    Returns: path to the converted checkpoint.
    Raises: ValueError if deepspeed_ckpt_path is not an existing directory ending in ".ckpt".
    '''
    # Drop trailing separators (e.g. from shell tab-completion) so "last.ckpt/" passes the
    # suffix check and the default output name is derived from the directory name.
    deepspeed_ckpt_path = os.path.normpath(deepspeed_ckpt_path)
    # Validate before importing pytorch_lightning so a bad path fails fast.
    if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
        raise ValueError(
            f'deepspeed_ckpt_path ({deepspeed_ckpt_path!r}) should point to the checkpoint directory'
            ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
        )
    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    # Convert state dict to PyTorch format
    if not pl_ckpt_path:
        pl_ckpt_path = os.path.splitext(deepspeed_ckpt_path)[0] + '.pt'  # .ckpt --> .pt
    # An existing output file is reused; the patch-up merge below runs either way.
    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)
    # Patch in missing parameters that failed to be converted by DeepSpeed utility
    pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(pl_ckpt, pl_ckpt_path)
    return pl_ckpt_path
def get_optim_files(checkpoint_dir):
    """Return full paths of every entry in ``checkpoint_dir`` whose name contains "optim".

    The paths are ordered by the sorted entry names.
    """
    optim_names = [entry for entry in os.listdir(checkpoint_dir) if "optim" in entry]
    optim_names.sort()
    return [join(checkpoint_dir, entry) for entry in optim_names]
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of a DeepSpeed "model_states" file inside ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory to search.
        zero_stage: Accepted for interface compatibility; not used.

    Returns:
        ``join(checkpoint_dir, name)`` for the lexicographically first entry whose name
        contains ``"model_states"``.

    Raises:
        FileNotFoundError: if no such entry exists.
    """
    # Sort so the choice does not depend on os.listdir's arbitrary order when several
    # model_states files are present (e.g. mp_rank_00 is chosen before mp_rank_01).
    candidates = sorted(f for f in os.listdir(checkpoint_dir) if "model_states" in f)
    if not candidates:
        raise FileNotFoundError(f'No "model_states" file found in {checkpoint_dir}')
    return join(checkpoint_dir, candidates[0])
def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
    '''
    Merge tensors whose keys are in the DeepSpeed checkpoint but missing from the fp32
    checkpoint into the fp32 state dict.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    fp32_ckpt_path: Path to the reconstructed fp32 checkpoint; it must contain a
        'state_dict' entry.
    Returns: the loaded fp32 checkpoint object, with its 'state_dict' updated in place
        (nothing is written to disk here).
    Raises: AssertionError if a parameter present in both checkpoints is not
        torch.allclose (atol=1e-2) between them.
    '''
    from pytorch_lightning.utilities.deepspeed import ds_checkpoint_dir
    # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt
    ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
    ds_sd = ds_ckpt['module']
    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']
    for k, v in ds_sd.items():
        # Keys not of the form "_forward_module.<name>" are reported and skipped. This is an
        # explicit check rather than a bare except, which would also swallow
        # KeyboardInterrupt/SystemExit.
        match = re.match(DS_PARAM_REGEX, k) if isinstance(k, str) else None
        if match is None:
            print(f'Failed to extract parameter from DeepSpeed key {k}')
            continue
        param_name = match.group(1)
        v = v.to(torch.float32)
        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
            fp32_sd[param_name] = v
        else:
            # Sanity check: parameters converted by the utility must agree with DeepSpeed's copy.
            assert torch.allclose(v, fp32_sd[param_name].to(torch.float32), atol=1e-2), (
                f'Parameter {param_name} differs between the DeepSpeed and fp32 checkpoints')
    return fp32_ckpt
def get_version_and_step(f, i):
    """Build a sortable ``(version, step, i)`` key from a checkpoint file name.

    The step is the text after the last ``"step="`` up to the next ``"."``; an optional
    ``-v<k>`` suffix on it gives the version, otherwise the version is 0. ``i`` is passed
    through unchanged so callers can map the key back to its entry.
    """
    stem = f.split("step=")[-1].split(".")[0]
    step_text, marker, version_text = stem.partition("-v")
    version = int(version_text) if marker else 0
    return version, int(step_text), i
def get_latest_ds(name, extra_args=None, checkpoint_dir="../checkpoints"):
    """Pick the newest DeepSpeed checkpoint for ``name``, converting it to fp32 if needed.

    Entries of ``<checkpoint_dir>/<name>`` are ranked by ``get_version_and_step``
    (version suffix first, then step, then listing position). Entries whose names it
    cannot parse (e.g. ``last.ckpt``) are reported and skipped. The selected checkpoint is
    converted with ``convert_deepspeed_checkpoint`` into
    ``<checkpoint_dir>/<name>_fp32/<file>`` unless that file already exists.

    Args:
        name: Run sub-directory under ``checkpoint_dir``.
        extra_args: Optional dict returned unchanged; defaults to ``{}``.
        checkpoint_dir: Root checkpoint directory. Defaults to ``"../checkpoints"``,
            which is resolved relative to the current working directory.

    Returns:
        ``dict(chkpt_name=<name>_fp32/<file>, extra_args=extra_args)`` where
        ``chkpt_name`` is relative to ``checkpoint_dir``.

    Raises:
        FileNotFoundError: if the directory is missing or holds no parseable checkpoint.
    """
    if extra_args is None:
        extra_args = dict()
    run_dir = join(checkpoint_dir, name)
    files = os.listdir(run_dir)
    ranked = []
    for i, f in enumerate(files):
        try:
            ranked.append(get_version_and_step(f, i))
        except ValueError:
            print(f"Skipping {f}: could not parse a step number from its name")
    if not ranked:
        raise FileNotFoundError(f"No DeepSpeed checkpoints found in {run_dir}")
    # Tuples compare as (version, step, index): highest version wins, then highest step.
    selected = files[max(ranked)[-1]]
    ds_chkpt = join(name, selected)
    reg_chkpt = join(name + "_fp32", selected)
    reg_chkpt_path = join(checkpoint_dir, reg_chkpt)
    if not os.path.exists(reg_chkpt_path):
        os.makedirs(os.path.dirname(reg_chkpt_path), exist_ok=True)
        print(f"Checkpoint {reg_chkpt} does not exist, converting from deepspeed")
        convert_deepspeed_checkpoint(join(checkpoint_dir, ds_chkpt), reg_chkpt_path)
    return dict(
        chkpt_name=reg_chkpt,
        extra_args=extra_args)
def get_all_models_in_dir(name, checkpoint_dir, extra_args=None):
    """Collect the latest checkpoint of every run directory under ``checkpoint_dir/name``.

    Each run ``<model_dir>`` is looked up by ``get_latest`` at
    ``<checkpoint_dir>/<name>/<model_dir>/train``.

    Args:
        name: Group directory under ``checkpoint_dir``.
        checkpoint_dir: Root checkpoint directory.
        extra_args: Passed unchanged to ``get_latest`` for every run.

    Returns:
        Dict mapping ``"<name>/<model_dir>/train"`` to the ``get_latest`` result, with
        runs inserted in sorted order.
    """
    group_dir = join(checkpoint_dir, name)
    ret = {}
    # Sorted for a deterministic key order. Plain files (e.g. .DS_Store, notes) are skipped
    # because get_latest can only list a directory.
    for model_dir in sorted(os.listdir(group_dir)):
        if not os.path.isdir(join(group_dir, model_dir)):
            continue
        full_name = f"{name}/{model_dir}/train"
        ret[full_name] = get_latest(full_name, checkpoint_dir, extra_args)
    return ret
def saved_model_dict(checkpoint_dir):
    """Build the registry of known models, keyed by name.

    The registry contains, in insertion order:
      * the latest checkpoint of every run found (via ``get_all_models_in_dir``) in four
        experiment groups, each with its own ``extra_args`` overrides;
      * three baselines without a checkpoint (``davenet``, ``cavmae``, ``imagebind``) that
        carry explicit ``extra_args`` and ``data_args``;
      * aliases (``denseav_language``, ``denseav_sound``, ``denseav_2head``) that point to
        the same entry objects as specific discovered runs.
    """
    def baseline(model_type):
        # cavmae and imagebind differ only in the model / aggregation type string.
        return {
            "chkpt_name": None,
            "extra_args": {
                "audio_blur": 1,
                "image_model_type": model_type,
                "image_aligner_type": None,
                "audio_model_type": model_type,
                "audio_aligner_type": None,
                "audio_input": "spec",
                "use_cached_embs": False,
                "sim_agg_heads": 1,
                "dropout": False,
                "nonneg_sim": False,
                "audio_lora": False,
                "image_lora": False,
                "norm_vectors": False,
                "learn_audio_cls": False,
                "sim_agg_type": model_type,
            },
            "data_args": {
                "use_cached_embs": False,
                "use_davenet_spec": True,
                "audio_model_type": model_type,
                "override_target_length": 10,
            },
        }

    experiment_groups = (
        ("9-5-23-mixed", {
            "mixup_weight": 0.0,
            "sim_use_cls": False,
            "audio_pool_width": 1,
            "memory_buffer_size": 0,
            "loss_leak": 0.0,
        }),
        ("1-23-24-rebuttal-heads", {"loss_leak": 0.0}),
        ("11-8-23", {"loss_leak": 0.0}),
        ("10-30-23-3", {"loss_leak": 0.0}),
    )
    model_info = {}
    for group, overrides in experiment_groups:
        model_info.update(get_all_models_in_dir(group, checkpoint_dir, extra_args=overrides))

    model_info["davenet"] = {
        "chkpt_name": None,
        "extra_args": {
            "audio_blur": 1,
            "image_model_type": "davenet",
            "image_aligner_type": None,
            "audio_model_type": "davenet",
            "audio_aligner_type": None,
            "audio_input": "davenet_spec",
            "use_cached_embs": False,
            "dropout": False,
            "sim_agg_heads": 1,
            "nonneg_sim": False,
            "audio_lora": False,
            "image_lora": False,
            "norm_vectors": False,
        },
        "data_args": {
            "use_cached_embs": False,
            "use_davenet_spec": True,
            "override_target_length": 20,
            "audio_model_type": "davenet",
        },
    }
    model_info["cavmae"] = baseline("cavmae")
    model_info["imagebind"] = baseline("imagebind")

    # Aliases share the referenced entry object rather than copying it.
    aliases = (
        ("denseav_language", "10-30-23-3/places_base/train"),
        ("denseav_sound", "11-8-23/hubert_1h_asf_cls_full_image_train_small_lr/train"),
        ("denseav_2head", "1-23-24-rebuttal-heads/mixed-2h/train"),
    )
    for alias, run_key in aliases:
        model_info[alias] = model_info[run_key]
    return model_info