import os
import sys

import torch

from .vqvae import VQ, VQVAE, DiVAE, VQControlNet
from .scheduling import *

def get_image_tokenizer(tokenizer_id: str,
                        tokenizers_root: str = './tokenizer_ckpts',
                        encoder_only: bool = False,
                        device: str = 'cuda',
                        verbose: bool = True,
                        return_None_on_fail: bool = False):
""" | |
Load a pretrained image tokenizer from a checkpoint. | |
Args: | |
tokenizer_id (str): ID of the tokenizer to load (name of the checkpoint file without ".pth"). | |
tokenizers_root (str): Path to the directory containing the tokenizer checkpoints. | |
encoder_only (bool): Set to True to load only the encoder part of the tokenizer. | |
device (str): Device to load the tokenizer on. | |
verbose (bool): Set to True to print load_state_dict warning/success messages | |
return_None_on_fail (bool): Set to True to return None if the tokenizer fails to load (e.g. doesn't exist) | |
Returns: | |
model (nn.Module): The loaded tokenizer. | |
""" | |
    if return_None_on_fail and not os.path.exists(os.path.join(tokenizers_root, f'{tokenizer_id}.pth')):
        return None

    if verbose:
        print(f'Loading tokenizer {tokenizer_id} ... ', end='')

    ckpt = torch.load(os.path.join(tokenizers_root, f'{tokenizer_id}.pth'), map_location='cpu')
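    # Checkpoints are expected to be dicts holding the model weights under 'model'
    # and the training argparse Namespace under 'args', as used throughout below.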
    # Handle arguments that were renamed or removed between training and inference code
    if 'CLIP' in ckpt['args'].domain or 'DINO' in ckpt['args'].domain or 'ImageBind' in ckpt['args'].domain:
        ckpt['args'].patch_proj = False
    elif 'sam' in ckpt['args'].domain:
        # SAM checkpoints store the resolution as mask_size; map it onto the input-size arguments
        ckpt['args'].input_size_min = ckpt['args'].mask_size
        ckpt['args'].input_size_max = ckpt['args'].mask_size
        ckpt['args'].input_size = ckpt['args'].mask_size
    ckpt['args'].quant_type = getattr(ckpt['args'], 'quantizer_type', None)
    ckpt['args'].enc_type = getattr(ckpt['args'], 'encoder_type', None)
    ckpt['args'].dec_type = getattr(ckpt['args'], 'decoder_type', None)
    ckpt['args'].image_size = getattr(ckpt['args'], 'input_size', None) or getattr(ckpt['args'], 'input_size_max', None)
    ckpt['args'].image_size_enc = getattr(ckpt['args'], 'input_size_enc', None)
    ckpt['args'].image_size_dec = getattr(ckpt['args'], 'input_size_dec', None)
    ckpt['args'].image_size_sd = getattr(ckpt['args'], 'input_size_sd', None)
    ckpt['args'].ema_decay = getattr(ckpt['args'], 'quantizer_ema_decay', None)
    ckpt['args'].enable_xformer = getattr(ckpt['args'], 'use_xformer', None)
    # Infer the number of labels/channels from the shapes of the checkpoint weights
    if 'cls_emb.weight' in ckpt['model']:
        ckpt['args'].n_labels, ckpt['args'].n_channels = ckpt['model']['cls_emb.weight'].shape
    elif 'encoder.linear_in.weight' in ckpt['model']:
        ckpt['args'].n_channels = ckpt['model']['encoder.linear_in.weight'].shape[1]
    else:
        ckpt['args'].n_channels = ckpt['model']['encoder.proj.weight'].shape[1]
    ckpt['args'].sync_codebook = False  # No distributed codebook syncing needed at inference time
    if encoder_only:
        model_type = VQ
        # Keep only the encoder weights; drop decoder and post-quantization projection weights
        ckpt['model'] = {k: v for k, v in ckpt['model'].items() if 'decoder' not in k and 'post_quant_proj' not in k}
    else:
        # TODO: Add the model type to the checkpoint when training so we can avoid this hackery
        if any('controlnet' in k for k in ckpt['model']):
            ckpt['args'].model_type = 'VQControlNet'
        elif hasattr(ckpt['args'], 'beta_schedule'):
            # Only diffusion (DiVAE) decoders define a noise schedule
            ckpt['args'].model_type = 'DiVAE'
        else:
            ckpt['args'].model_type = 'VQVAE'
        model_type = getattr(sys.modules[__name__], ckpt['args'].model_type)
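    # Instantiate the selected class with the (patched) training args and load the
    # weights non-strictly, since legacy checkpoints may have extra or missing keys.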
    model = model_type(**vars(ckpt['args']))
    msg = model.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(msg)

    return model.to(device).eval(), ckpt['args']
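
# Example usage (a minimal sketch; 'vq_tokenizer' is a hypothetical checkpoint name,
# i.e. it assumes a file ./tokenizer_ckpts/vq_tokenizer.pth exists):
#
#   tokenizer, args = get_image_tokenizer('vq_tokenizer')
#   encoder, _ = get_image_tokenizer('vq_tokenizer', encoder_only=True, device='cpu')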