Commit · 20c21fe
Parent(s): 05469a1

nooooooo se
Files changed:
- app.py (+2 -2)
- src/model_loader.py (+3 -13)

app.py CHANGED

@@ -46,8 +46,8 @@ config = Config(
     tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 )
 
-# Load models
-config.models = model_loader.load_models(str(model_file), device, use_se=True)
+# Load models
+config.models = model_loader.load_models(str(model_file), device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
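
For context, a minimal sketch of how the simplified call sits in app.py after this commit. The names model_loader, model_file, and device come from the diff context; the import path, checkpoint filename, and device-selection line are assumptions for illustration only.

import torch
from src import model_loader  # import path assumed; the module edited in this commit

# Assumed setup: pick a device and point at the checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
model_file = "weights/model.ckpt"  # hypothetical path

# After this change the loader takes only the checkpoint path and the device;
# the old use_se flag is gone.
models = model_loader.load_models(str(model_file), device)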

src/model_loader.py CHANGED

@@ -6,7 +6,7 @@ from .diffusion import Diffusion
 from . import model_converter
 import torch
 
-def load_models(ckpt_path, device, use_se=False):
+def load_models(ckpt_path, device):
     state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
 
     encoder = VAE_Encoder().to(device)
@@ -15,19 +15,9 @@ def load_models(ckpt_path, device, use_se=False):
     decoder = VAE_Decoder().to(device)
     decoder.load_state_dict(state_dict['decoder'], strict=True)
 
-    # Initialize diffusion model
-    diffusion = Diffusion(use_se=False).to(device)
+    # Initialize diffusion model
+    diffusion = Diffusion().to(device)
     diffusion.load_state_dict(state_dict['diffusion'], strict=True)
-
-    # If SE blocks are requested, reinitialize the model with them
-    if use_se:
-        diffusion = Diffusion(use_se=True).to(device)
-        # Copy the weights from the loaded model
-        with torch.no_grad():
-            for name, param in diffusion.named_parameters():
-                if 'se' not in name:  # Skip SE block parameters
-                    if name in state_dict['diffusion']:
-                        param.copy_(state_dict['diffusion'][name])
 
     clip = CLIP().to(device)
     clip.load_state_dict(state_dict['clip'], strict=True)
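
For reference, the branch removed above can be captured as a small standalone helper. This is a sketch only: it assumes Diffusion still exposes a use_se flag and names its squeeze-and-excitation parameters with 'se', exactly as in the deleted lines; the helper name itself is hypothetical.

import torch
from .diffusion import Diffusion  # same relative import used by model_loader.py

def load_diffusion_with_se(diffusion_state_dict, device):
    # Build the SE-augmented model, then copy over only the weights that the
    # plain checkpoint provides; freshly initialized SE parameters are left untouched.
    diffusion = Diffusion(use_se=True).to(device)
    with torch.no_grad():
        for name, param in diffusion.named_parameters():
            if 'se' not in name and name in diffusion_state_dict:
                param.copy_(diffusion_state_dict[name])
    return diffusion

Much the same effect is available from diffusion.load_state_dict(diffusion_state_dict, strict=False), which leaves any parameters missing from the checkpoint (here the SE weights) at their initial values.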