ChongMou commited on
Commit
4f29a2b
·
1 Parent(s): 0cf589a

Update demo/model.py

Browse files
Files changed (1) hide show
  1. demo/model.py +7 -8
demo/model.py CHANGED
@@ -104,8 +104,7 @@ class Model_all:
104
  self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
105
  self.config.model.params.cond_stage_config.params.device = device
106
  self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
107
- self.current_base_pose = 'sd-v1-4.ckpt'
108
- self.current_base_sketch = 'sd-v1-4.ckpt'
109
  self.sampler = PLMSSampler(self.base_model)
110
 
111
  # sketch part
@@ -144,7 +143,7 @@ class Model_all:
144
 
145
  @torch.no_grad()
146
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
147
- if self.current_base_sketch != base_model:
148
  ckpt = os.path.join("models", base_model)
149
  pl_sd = torch.load(ckpt, map_location="cpu")
150
  if "state_dict" in pl_sd:
@@ -152,7 +151,7 @@ class Model_all:
152
  else:
153
  sd = pl_sd
154
  self.base_model.load_state_dict(sd, strict=False)
155
- self.current_base_sketch = base_model
156
  # del sd
157
  # del pl_sd
158
  con_strength = int((1-con_strength)*50)
@@ -218,7 +217,7 @@ class Model_all:
218
 
219
  @torch.no_grad()
220
  def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
221
- if self.current_base_sketch != base_model:
222
  ckpt = os.path.join("models", base_model)
223
  pl_sd = torch.load(ckpt, map_location="cpu")
224
  if "state_dict" in pl_sd:
@@ -226,7 +225,7 @@ class Model_all:
226
  else:
227
  sd = pl_sd
228
  self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
229
- self.current_base_sketch = base_model
230
  con_strength = int((1-con_strength)*50)
231
  if fix_sample == 'True':
232
  seed_everything(42)
@@ -288,7 +287,7 @@ class Model_all:
288
 
289
  @torch.no_grad()
290
  def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
291
- if self.current_base_pose != base_model:
292
  ckpt = os.path.join("models", base_model)
293
  pl_sd = torch.load(ckpt, map_location="cpu")
294
  if "state_dict" in pl_sd:
@@ -296,7 +295,7 @@ class Model_all:
296
  else:
297
  sd = pl_sd
298
  self.base_model.load_state_dict(sd, strict=False)
299
- self.current_base_pose = base_model
300
  con_strength = int((1-con_strength)*50)
301
  if fix_sample == 'True':
302
  seed_everything(42)
 
104
  self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
105
  self.config.model.params.cond_stage_config.params.device = device
106
  self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
107
+ self.current_base = 'sd-v1-4.ckpt'
 
108
  self.sampler = PLMSSampler(self.base_model)
109
 
110
  # sketch part
 
143
 
144
  @torch.no_grad()
145
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
146
+ if self.current_base != base_model:
147
  ckpt = os.path.join("models", base_model)
148
  pl_sd = torch.load(ckpt, map_location="cpu")
149
  if "state_dict" in pl_sd:
 
151
  else:
152
  sd = pl_sd
153
  self.base_model.load_state_dict(sd, strict=False)
154
+ self.current_base = base_model
155
  # del sd
156
  # del pl_sd
157
  con_strength = int((1-con_strength)*50)
 
217
 
218
  @torch.no_grad()
219
  def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
220
+ if self.current_base != base_model:
221
  ckpt = os.path.join("models", base_model)
222
  pl_sd = torch.load(ckpt, map_location="cpu")
223
  if "state_dict" in pl_sd:
 
225
  else:
226
  sd = pl_sd
227
  self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
228
+ self.current_base = base_model
229
  con_strength = int((1-con_strength)*50)
230
  if fix_sample == 'True':
231
  seed_everything(42)
 
287
 
288
  @torch.no_grad()
289
  def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
290
+ if self.current_base != base_model:
291
  ckpt = os.path.join("models", base_model)
292
  pl_sd = torch.load(ckpt, map_location="cpu")
293
  if "state_dict" in pl_sd:
 
295
  else:
296
  sd = pl_sd
297
  self.base_model.load_state_dict(sd, strict=False)
298
+ self.current_base = base_model
299
  con_strength = int((1-con_strength)*50)
300
  if fix_sample == 'True':
301
  seed_everything(42)