Adapter commited on
Commit
fc5fab1
·
1 Parent(s): 1e3fd43
Files changed (2) hide show
  1. app.py +1 -1
  2. demo/model.py +14 -50
app.py CHANGED
@@ -17,7 +17,7 @@ from huggingface_hub import hf_hub_url
17
  urls = {
18
  'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
- 'andite/anything-v4.0':['anything-v4.0-pruned.ckpt'],
21
  }
22
  urls_mmpose = [
23
  'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
 
17
  urls = {
18
  'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
+ 'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
21
  }
22
  urls_mmpose = [
23
  'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
demo/model.py CHANGED
@@ -149,6 +149,11 @@ class Model_all:
149
  [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
150
  [51, 153, 255],
151
  [51, 153, 255], [51, 153, 255], [51, 153, 255]]
 
 
 
 
 
152
 
153
  @torch.no_grad()
154
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
@@ -160,12 +165,11 @@ class Model_all:
160
  sd = pl_sd["state_dict"]
161
  else:
162
  sd = pl_sd
163
- # self.base_model = self.base_model.cpu()
164
  self.base_model.load_state_dict(sd, strict=False)
165
- # self.base_model = self.base_model.cuda()
166
  self.current_base = base_model
167
- # del sd
168
- # del pl_sd
 
169
  con_strength = int((1 - con_strength) * 50)
170
  if fix_sample == 'True':
171
  seed_everything(42)
@@ -185,23 +189,12 @@ class Model_all:
185
  im = im.float()
186
  im_edge = tensor2img(im)
187
 
188
- # # save gpu memory
189
- # self.base_model.model = self.base_model.model.cpu()
190
- # self.model_sketch = self.model_sketch.cuda()
191
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
192
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
193
-
194
  # extract condition features
195
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
196
  nc = self.base_model.get_learned_conditioning([neg_prompt])
197
  features_adapter = self.model_sketch(im.to(self.device))
198
  shape = [4, 64, 64]
199
 
200
- # # save gpu memory
201
- # self.model_sketch = self.model_sketch.cpu()
202
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
203
- # self.base_model.model = self.base_model.model.cuda()
204
-
205
  # sampling
206
  samples_ddim, _ = self.sampler.sample(S=50,
207
  conditioning=c,
@@ -215,8 +208,6 @@ class Model_all:
215
  features_adapter1=features_adapter,
216
  mode='sketch',
217
  con_strength=con_strength)
218
- # # save gpu memory
219
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
220
 
221
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
222
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -236,10 +227,11 @@ class Model_all:
236
  sd = pl_sd["state_dict"]
237
  else:
238
  sd = pl_sd
239
- # self.base_model = self.base_model.cpu()
240
  self.base_model.load_state_dict(sd, strict=False)
241
- # self.base_model = self.base_model.cuda()
242
  self.current_base = base_model
 
 
 
243
  con_strength = int((1 - con_strength) * 50)
244
  if fix_sample == 'True':
245
  seed_everything(42)
@@ -250,29 +242,17 @@ class Model_all:
250
  im = im.clip(0, 255).astype(np.uint8)
251
  im = cv2.resize(im, (512, 512))
252
 
253
- # im = 255-im
254
  im_edge = im.copy()
255
  im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
256
  im = im > 0.5
257
  im = im.float()
258
 
259
- # # save gpu memory
260
- # self.base_model.model = self.base_model.model.cpu()
261
- # self.model_sketch = self.model_sketch.cuda()
262
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
263
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
264
-
265
  # extract condition features
266
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
267
  nc = self.base_model.get_learned_conditioning([neg_prompt])
268
  features_adapter = self.model_sketch(im.to(self.device))
269
  shape = [4, 64, 64]
270
 
271
- # # save gpu memory
272
- # self.model_sketch = self.model_sketch.cpu()
273
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
274
- # self.base_model.model = self.base_model.model.cuda()
275
-
276
  # sampling
277
  samples_ddim, _ = self.sampler.sample(S=50,
278
  conditioning=c,
@@ -287,9 +267,6 @@ class Model_all:
287
  mode='sketch',
288
  con_strength=con_strength)
289
 
290
- # # save gpu memory
291
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
292
-
293
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
294
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
295
  x_samples_ddim = x_samples_ddim.to('cpu')
@@ -309,10 +286,11 @@ class Model_all:
309
  sd = pl_sd["state_dict"]
310
  else:
311
  sd = pl_sd
312
- # self.base_model = self.base_model.cpu()
313
  self.base_model.load_state_dict(sd, strict=False)
314
- # self.base_model = self.base_model.cuda()
315
  self.current_base = base_model
 
 
 
316
  con_strength = int((1 - con_strength) * 50)
317
  if fix_sample == 'True':
318
  seed_everything(42)
@@ -356,12 +334,6 @@ class Model_all:
356
  thickness=2)
357
  im_pose = cv2.resize(im_pose, (512, 512))
358
 
359
- # # save gpu memory
360
- # self.base_model.model = self.base_model.model.cpu()
361
- # self.model_pose = self.model_pose.cuda()
362
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
363
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
364
-
365
  # extract condition features
366
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
367
  nc = self.base_model.get_learned_conditioning([neg_prompt])
@@ -369,11 +341,6 @@ class Model_all:
369
  pose = pose.unsqueeze(0)
370
  features_adapter = self.model_pose(pose.to(self.device))
371
 
372
- # # save gpu memory
373
- # self.model_pose = self.model_pose.cpu()
374
- # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
375
- # self.base_model.model = self.base_model.model.cuda()
376
-
377
  shape = [4, 64, 64]
378
 
379
  # sampling
@@ -390,9 +357,6 @@ class Model_all:
390
  mode='sketch',
391
  con_strength=con_strength)
392
 
393
- # # save gpu memory
394
- # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
395
-
396
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
397
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
398
  x_samples_ddim = x_samples_ddim.to('cpu')
 
149
  [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
150
  [51, 153, 255],
151
  [51, 153, 255], [51, 153, 255], [51, 153, 255]]
152
+
153
+ def load_vae(self):
154
+ vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
155
+ sd = vae_sd["state_dict"]
156
+ self.base_model.first_stage_model.load_state_dict(sd, strict=False)
157
 
158
  @torch.no_grad()
159
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
 
165
  sd = pl_sd["state_dict"]
166
  else:
167
  sd = pl_sd
 
168
  self.base_model.load_state_dict(sd, strict=False)
 
169
  self.current_base = base_model
170
+ if 'anything' in base_model.lower():
171
+ self.load_vae()
172
+
173
  con_strength = int((1 - con_strength) * 50)
174
  if fix_sample == 'True':
175
  seed_everything(42)
 
189
  im = im.float()
190
  im_edge = tensor2img(im)
191
 
 
 
 
 
 
 
192
  # extract condition features
193
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
194
  nc = self.base_model.get_learned_conditioning([neg_prompt])
195
  features_adapter = self.model_sketch(im.to(self.device))
196
  shape = [4, 64, 64]
197
 
 
 
 
 
 
198
  # sampling
199
  samples_ddim, _ = self.sampler.sample(S=50,
200
  conditioning=c,
 
208
  features_adapter1=features_adapter,
209
  mode='sketch',
210
  con_strength=con_strength)
 
 
211
 
212
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
213
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
227
  sd = pl_sd["state_dict"]
228
  else:
229
  sd = pl_sd
 
230
  self.base_model.load_state_dict(sd, strict=False)
 
231
  self.current_base = base_model
232
+ if 'anything' in base_model.lower():
233
+ self.load_vae()
234
+
235
  con_strength = int((1 - con_strength) * 50)
236
  if fix_sample == 'True':
237
  seed_everything(42)
 
242
  im = im.clip(0, 255).astype(np.uint8)
243
  im = cv2.resize(im, (512, 512))
244
 
 
245
  im_edge = im.copy()
246
  im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
247
  im = im > 0.5
248
  im = im.float()
249
 
 
 
 
 
 
 
250
  # extract condition features
251
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
252
  nc = self.base_model.get_learned_conditioning([neg_prompt])
253
  features_adapter = self.model_sketch(im.to(self.device))
254
  shape = [4, 64, 64]
255
 
 
 
 
 
 
256
  # sampling
257
  samples_ddim, _ = self.sampler.sample(S=50,
258
  conditioning=c,
 
267
  mode='sketch',
268
  con_strength=con_strength)
269
 
 
 
 
270
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
271
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
272
  x_samples_ddim = x_samples_ddim.to('cpu')
 
286
  sd = pl_sd["state_dict"]
287
  else:
288
  sd = pl_sd
 
289
  self.base_model.load_state_dict(sd, strict=False)
 
290
  self.current_base = base_model
291
+ if 'anything' in base_model.lower():
292
+ self.load_vae()
293
+
294
  con_strength = int((1 - con_strength) * 50)
295
  if fix_sample == 'True':
296
  seed_everything(42)
 
334
  thickness=2)
335
  im_pose = cv2.resize(im_pose, (512, 512))
336
 
 
 
 
 
 
 
337
  # extract condition features
338
  c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
339
  nc = self.base_model.get_learned_conditioning([neg_prompt])
 
341
  pose = pose.unsqueeze(0)
342
  features_adapter = self.model_pose(pose.to(self.device))
343
 
 
 
 
 
 
344
  shape = [4, 64, 64]
345
 
346
  # sampling
 
357
  mode='sketch',
358
  con_strength=con_strength)
359
 
 
 
 
360
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
361
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
362
  x_samples_ddim = x_samples_ddim.to('cpu')