File size: 1,388 Bytes
b876688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion

import model_converter
import torch

def load_models(ckpt_path, device, use_se=False):
    """Build the CLIP, VAE encoder/decoder and diffusion models from a checkpoint.

    The checkpoint is converted by ``model_converter.load_from_standard_weights``
    into a dict whose ``'encoder'``, ``'decoder'``, ``'diffusion'`` and
    ``'clip'`` entries are loaded into the corresponding modules.

    Args:
        ckpt_path: Checkpoint location, passed straight to
            ``model_converter.load_from_standard_weights``.
        device: Target device; every model is moved there with ``.to(device)``
            and it is also forwarded to the weight converter.
        use_se: If True, return a ``Diffusion(use_se=True)`` model. The
            checkpoint is first validated strictly against the SE-free
            architecture, so parameters that exist only in the SE-enabled
            model (the SE blocks) keep their fresh initialization. Every
            checkpoint tensor must still find a counterpart in the SE model.

    Returns:
        A dict mapping ``'clip'``, ``'encoder'``, ``'decoder'`` and
        ``'diffusion'`` to the loaded modules.

    Raises:
        RuntimeError: If a state dict does not match its module (raised by
            ``load_state_dict``), or, with ``use_se=True``, if some checkpoint
            tensors have no matching entry in the SE-enabled diffusion model.
    """
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    # Strict load into the SE-free model: this validates that the checkpoint
    # is complete and matches the reference architecture exactly.
    diffusion = Diffusion(use_se=False).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    if use_se:
        # Release the validated SE-free model before building the larger one,
        # so the two are not resident on the device at the same time.
        del diffusion
        diffusion = Diffusion(use_se=True).to(device)
        # load_state_dict matches keys exactly and restores buffers as well as
        # parameters (a manual named_parameters() copy skips buffers, and a
        # substring filter like "'se' in name" can wrongly skip real weights).
        # Keys that exist only in the SE model come back as missing_keys and
        # keep their fresh initialization, which is the intended behavior.
        result = diffusion.load_state_dict(state_dict['diffusion'], strict=False)
        if result.unexpected_keys:
            # Any unconsumed checkpoint tensor means part of the pre-trained
            # network would silently stay randomly initialized.
            raise RuntimeError(
                "Checkpoint tensors have no counterpart in Diffusion(use_se=True): "
                + ", ".join(result.unexpected_keys)
            )

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }