Spaces:
Runtime error
Runtime error
ChongMou
commited on
Commit
·
4f29a2b
1
Parent(s):
0cf589a
Update demo/model.py
Browse files- demo/model.py +7 -8
demo/model.py
CHANGED
@@ -104,8 +104,7 @@ class Model_all:
|
|
104 |
self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
|
105 |
self.config.model.params.cond_stage_config.params.device = device
|
106 |
self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
|
107 |
-
self.
|
108 |
-
self.current_base_sketch = 'sd-v1-4.ckpt'
|
109 |
self.sampler = PLMSSampler(self.base_model)
|
110 |
|
111 |
# sketch part
|
@@ -144,7 +143,7 @@ class Model_all:
|
|
144 |
|
145 |
@torch.no_grad()
|
146 |
def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
147 |
-
if self.
|
148 |
ckpt = os.path.join("models", base_model)
|
149 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
150 |
if "state_dict" in pl_sd:
|
@@ -152,7 +151,7 @@ class Model_all:
|
|
152 |
else:
|
153 |
sd = pl_sd
|
154 |
self.base_model.load_state_dict(sd, strict=False)
|
155 |
-
self.
|
156 |
# del sd
|
157 |
# del pl_sd
|
158 |
con_strength = int((1-con_strength)*50)
|
@@ -218,7 +217,7 @@ class Model_all:
|
|
218 |
|
219 |
@torch.no_grad()
|
220 |
def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
221 |
-
if self.
|
222 |
ckpt = os.path.join("models", base_model)
|
223 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
224 |
if "state_dict" in pl_sd:
|
@@ -226,7 +225,7 @@ class Model_all:
|
|
226 |
else:
|
227 |
sd = pl_sd
|
228 |
self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
|
229 |
-
self.
|
230 |
con_strength = int((1-con_strength)*50)
|
231 |
if fix_sample == 'True':
|
232 |
seed_everything(42)
|
@@ -288,7 +287,7 @@ class Model_all:
|
|
288 |
|
289 |
@torch.no_grad()
|
290 |
def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
291 |
-
if self.
|
292 |
ckpt = os.path.join("models", base_model)
|
293 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
294 |
if "state_dict" in pl_sd:
|
@@ -296,7 +295,7 @@ class Model_all:
|
|
296 |
else:
|
297 |
sd = pl_sd
|
298 |
self.base_model.load_state_dict(sd, strict=False)
|
299 |
-
self.
|
300 |
con_strength = int((1-con_strength)*50)
|
301 |
if fix_sample == 'True':
|
302 |
seed_everything(42)
|
|
|
104 |
self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
|
105 |
self.config.model.params.cond_stage_config.params.device = device
|
106 |
self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
|
107 |
+
self.current_base = 'sd-v1-4.ckpt'
|
|
|
108 |
self.sampler = PLMSSampler(self.base_model)
|
109 |
|
110 |
# sketch part
|
|
|
143 |
|
144 |
@torch.no_grad()
|
145 |
def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
146 |
+
if self.current_base != base_model:
|
147 |
ckpt = os.path.join("models", base_model)
|
148 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
149 |
if "state_dict" in pl_sd:
|
|
|
151 |
else:
|
152 |
sd = pl_sd
|
153 |
self.base_model.load_state_dict(sd, strict=False)
|
154 |
+
self.current_base = base_model
|
155 |
# del sd
|
156 |
# del pl_sd
|
157 |
con_strength = int((1-con_strength)*50)
|
|
|
217 |
|
218 |
@torch.no_grad()
|
219 |
def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
220 |
+
if self.current_base != base_model:
|
221 |
ckpt = os.path.join("models", base_model)
|
222 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
223 |
if "state_dict" in pl_sd:
|
|
|
225 |
else:
|
226 |
sd = pl_sd
|
227 |
self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
|
228 |
+
self.current_base = base_model
|
229 |
con_strength = int((1-con_strength)*50)
|
230 |
if fix_sample == 'True':
|
231 |
seed_everything(42)
|
|
|
287 |
|
288 |
@torch.no_grad()
|
289 |
def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
290 |
+
if self.current_base != base_model:
|
291 |
ckpt = os.path.join("models", base_model)
|
292 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
293 |
if "state_dict" in pl_sd:
|
|
|
295 |
else:
|
296 |
sd = pl_sd
|
297 |
self.base_model.load_state_dict(sd, strict=False)
|
298 |
+
self.current_base = base_model
|
299 |
con_strength = int((1-con_strength)*50)
|
300 |
if fix_sample == 'True':
|
301 |
seed_everything(42)
|