from pathlib import Path
from typing import Optional, Tuple
import logging

import torch

from gyraudio.audio_separation.properties import SHORT_NAME, MODEL, OPTIMIZER, CURRENT_EPOCH, CONFIGURATION
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT

def get_output_folder(config: dict, root_dir: Path = EXPERIMENT_STORAGE_ROOT, override: bool = False) -> Tuple[bool, Path]:
    """Resolve (and create if needed) the output folder for an experiment.

    Returns (exists, output_folder): exists is True when the folder is ready
    to use (freshly created, or pre-existing with override=True).
    """
    output_folder = root_dir/config[SHORT_NAME]
    exists = False
    if output_folder.exists():
        if not override:
            logging.info(f"Experiment {config[SHORT_NAME]} already exists. Override is set to False. Skipping.")
        else:
            logging.warning(f"Experiment {config[SHORT_NAME]} will be OVERRIDDEN")
            exists = True
    else:
        output_folder.mkdir(parents=True, exist_ok=True)
        exists = True
    return exists, output_folder
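
# Usage sketch (the config contents here are illustrative, not from the module):
#     config = {SHORT_NAME: "exp_0001"}
#     ready, out_dir = get_output_folder(config, root_dir=Path("/tmp/experiments"))
#     if ready:
#         ...  # safe to write results under out_dir
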
def checkpoint_paths(exp_dir: Path, epoch: Optional[int] = None) -> Tuple[Path, Path, int]:
    """Locate the model/optimizer checkpoint pair for a given epoch.

    When epoch is None, pick the latest checkpoint: the zero-padded names
    sort lexicographically, so the last glob match is the most recent epoch.
    """
    if epoch is None:
        checkpoints = sorted(exp_dir.glob("model_*.pt"))
        assert len(checkpoints) > 0, f"No checkpoints found in {exp_dir}"
        model_checkpoint = checkpoints[-1]
        epoch = int(model_checkpoint.stem.split("_")[-1])
        # .name keeps the .pt suffix (.stem would drop it and yield a bad path)
        optimizer_checkpoint = exp_dir/model_checkpoint.name.replace("model", "optimizer")
    else:
        model_checkpoint = exp_dir/f"model_{epoch:04d}.pt"
        optimizer_checkpoint = exp_dir/f"optimizer_{epoch:04d}.pt"
    return model_checkpoint, optimizer_checkpoint, epoch
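
# Checkpoint naming convention produced/consumed above (epoch 42 shown as an example):
#     exp_dir/model_0042.pt      -> model weights
#     exp_dir/optimizer_0042.pt  -> optimizer state + epoch + configuration
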
def load_checkpoint(model, exp_dir: Path, optimizer=None, epoch: Optional[int] = None,
                    device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Restore model (and optionally optimizer) state from an experiment folder.

    The configuration is stored alongside the optimizer state, so it is only
    recovered when an optimizer is passed in.
    """
    config = {}
    model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    model_state_dict = torch.load(model_checkpoint, map_location=torch.device(device))
    model.load_state_dict(model_state_dict[MODEL])
    if optimizer is not None:
        optimizer_state_dict = torch.load(optimizer_checkpoint, map_location=torch.device(device))
        optimizer.load_state_dict(optimizer_state_dict[OPTIMIZER])
        config = optimizer_state_dict[CONFIGURATION]
    return model, optimizer, epoch, config
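
# Resume sketch (MyModel and exp_dir are hypothetical placeholders):
#     model = MyModel()
#     optimizer = torch.optim.Adam(model.parameters())
#     model, optimizer, epoch, config = load_checkpoint(model, exp_dir, optimizer=optimizer)
#     start_epoch = epoch + 1  # training continues after the restored epoch
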
def save_checkpoint(model, exp_dir: Path, optimizer=None, config: Optional[dict] = None, epoch: Optional[int] = None):
    """Save a model/optimizer checkpoint pair for the given epoch.

    The epoch and configuration are bundled with the optimizer state so that
    training can later be resumed from the same point.
    """
    model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    torch.save(
        {
            MODEL: model.state_dict(),
        },
        model_checkpoint
    )
    torch.save(
        {
            CURRENT_EPOCH: epoch,
            CONFIGURATION: config if config is not None else {},
            # Guard against the default optimizer=None (would crash on .state_dict())
            OPTIMIZER: optimizer.state_dict() if optimizer is not None else None,
        },
        optimizer_checkpoint
    )
    logging.info(f"Checkpoint saved:\n - model: {model_checkpoint}\n - optimizer: {optimizer_checkpoint}")