torinriley commited on
Commit
20c21fe
·
1 Parent(s): 05469a1

nooooooo se

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. src/model_loader.py +3 -13
app.py CHANGED
@@ -46,8 +46,8 @@ config = Config(
46
  tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
47
  )
48
 
49
- # Load models with SE blocks enabled
50
- config.models = model_loader.load_models(str(model_file), device, use_se=True)
51
 
52
  MAX_SEED = np.iinfo(np.int32).max
53
  MAX_IMAGE_SIZE = 1024
 
46
  tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
47
  )
48
 
49
+ # Load models
50
+ config.models = model_loader.load_models(str(model_file), device)
51
 
52
  MAX_SEED = np.iinfo(np.int32).max
53
  MAX_IMAGE_SIZE = 1024
src/model_loader.py CHANGED
@@ -6,7 +6,7 @@ from .diffusion import Diffusion
6
  from . import model_converter
7
  import torch
8
 
9
- def load_models(ckpt_path, device, use_se=False):
10
  state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
11
 
12
  encoder = VAE_Encoder().to(device)
@@ -15,19 +15,9 @@ def load_models(ckpt_path, device, use_se=False):
15
  decoder = VAE_Decoder().to(device)
16
  decoder.load_state_dict(state_dict['decoder'], strict=True)
17
 
18
- # Initialize diffusion model with SE blocks disabled for loading pre-trained weights
19
- diffusion = Diffusion(use_se=False).to(device)
20
  diffusion.load_state_dict(state_dict['diffusion'], strict=True)
21
-
22
- # If SE blocks are requested, reinitialize the model with them
23
- if use_se:
24
- diffusion = Diffusion(use_se=True).to(device)
25
- # Copy the weights from the loaded model
26
- with torch.no_grad():
27
- for name, param in diffusion.named_parameters():
28
- if 'se' not in name: # Skip SE block parameters
29
- if name in state_dict['diffusion']:
30
- param.copy_(state_dict['diffusion'][name])
31
 
32
  clip = CLIP().to(device)
33
  clip.load_state_dict(state_dict['clip'], strict=True)
 
6
  from . import model_converter
7
  import torch
8
 
9
+ def load_models(ckpt_path, device):
10
  state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
11
 
12
  encoder = VAE_Encoder().to(device)
 
15
  decoder = VAE_Decoder().to(device)
16
  decoder.load_state_dict(state_dict['decoder'], strict=True)
17
 
18
+ # Initialize diffusion model
19
+ diffusion = Diffusion().to(device)
20
  diffusion.load_state_dict(state_dict['diffusion'], strict=True)
 
 
 
 
 
 
 
 
 
 
21
 
22
  clip = CLIP().to(device)
23
  clip.load_state_dict(state_dict['clip'], strict=True)