import torch
from torch import nn
from functools import reduce
from pathlib import Path

from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
from ema_pytorch import EMA


def exists(val):
    """Return True if ``val`` is not None."""
    return val is not None


def safeget(dictionary, keys, default = None):
    """Look up a dot-separated key path (e.g. ``'a.b.c'``) in nested dicts.

    ``keys`` is split on ``'.'`` and each segment is looked up in turn with
    ``dict.get(segment, default)``. If the current value at any step is not a
    ``dict``, ``default`` is returned for that step.
    """
    return reduce(
        lambda d, key: d.get(key, default) if isinstance(d, dict) else default,
        keys.split('.'),
        dictionary
    )


def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Rebuild an Imagen model from a checkpoint that stores its configuration.

    The checkpoint must contain ``'imagen_type'`` (``'original'`` or
    ``'elucidated'``) and ``'imagen_params'`` (keyword arguments for the
    matching config class). Weights are read from ``'model'``. If requested,
    EMA weights are read from ``'ema'``.

    Args:
        checkpoint_path: path to the checkpoint file.
        load_weights: if False, return the freshly constructed model without
            loading any weights.
        load_ema_if_available: if True and the checkpoint has an ``'ema'``
            entry, overwrite each unet's weights with its EMA weights.

    Returns:
        The model produced by ``ImagenConfig(...).create()`` or
        ``ElucidatedImagenConfig(...).create()``.

    Raises:
        AssertionError: if the checkpoint file does not exist, or if the
            checkpoint lacks the saved imagen type / configuration.
        ValueError: if the saved imagen type is not recognised.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'

    # NOTE(review): depending on the installed torch version / `weights_only`
    # default, torch.load may unpickle arbitrary objects - only load trusted files.
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # Check that the configuration was saved *before* dispatching on the type.
    # Otherwise a missing type falls through to the "unknown imagen type None"
    # error and hides the real cause.
    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # Wrap each unet in an EMA container so the saved EMA state dict (which was
    # presumably produced from the same structure - confirm against the saver)
    # can be loaded. Then copy the EMA weights back into the live unets.
    ema_unets = nn.ModuleList([EMA(unet) for unet in imagen.unets])
    ema_unets.load_state_dict(loaded['ema'])

    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen