from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion
import model_converter
import torch


def load_models(ckpt_path, device, use_se=False):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    # Build the diffusion model without SE blocks so the pre-trained
    # checkpoint loads cleanly under strict=True.
    diffusion = Diffusion(use_se=False).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    # If SE blocks are requested, rebuild the model with them and copy the
    # pre-trained weights over from the checkpoint; the freshly initialized
    # SE parameters have no checkpoint entry and are left untouched.
    if use_se:
        diffusion = Diffusion(use_se=True).to(device)
        with torch.no_grad():
            for name, param in diffusion.named_parameters():
                # Skip SE block parameters and anything not in the checkpoint
                if 'se' not in name and name in state_dict['diffusion']:
                    param.copy_(state_dict['diffusion'][name])

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)
    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }