RamAnanth1 committed on
Commit b6c5945 · 1 Parent(s): 9d3a8c0

Update app.py

Files changed (1)
  1. app.py +42 -551
app.py CHANGED
@@ -20,6 +20,8 @@ from langchain.llms.openai import OpenAI
20
  import re
21
  import uuid
22
from diffusers import StableDiffusionInpaintPipeline
23
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
24
+ from diffusers import UniPCMultistepScheduler
23
  from PIL import Image
24
  import numpy as np
25
  from omegaconf import OmegaConf
@@ -28,16 +30,6 @@ import cv2
28
  import einops
29
  from pytorch_lightning import seed_everything
30
  import random
31
- from ldm.util import instantiate_from_config
32
- from ControlNet.cldm.model import create_model, load_state_dict
33
- from ControlNet.cldm.ddim_hacked import DDIMSampler
34
- from ControlNet.annotator.canny import CannyDetector
35
- from ControlNet.annotator.mlsd import MLSDdetector
36
- from ControlNet.annotator.util import HWC3, resize_image
37
- from ControlNet.annotator.hed import HEDdetector, nms
38
- from ControlNet.annotator.openpose import OpenposeDetector
39
- from ControlNet.annotator.uniformer import UniformerDetector
40
- from ControlNet.annotator.midas import MidasDetector
41
 
42
  VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
43
 
@@ -223,7 +215,6 @@ class ImageCaptioning:
223
  class image2canny:
224
  def __init__(self):
225
  print("Direct detect canny.")
226
- self.detector = CannyDetector()
227
  self.low_thresh = 100
228
  self.high_thresh = 200
229
 
@@ -231,558 +222,58 @@ class image2canny:
231
  print("===>Starting image2canny Inference")
232
  image = Image.open(inputs)
233
  image = np.array(image)
234
- canny = self.detector(image, self.low_thresh, self.high_thresh)
235
- canny = 255 - canny
236
- image = Image.fromarray(canny)
225
+
226
+ image = cv2.Canny(image, self.low_thresh, self.high_thresh)
227
+ image = image[:, :, None]
228
+ image = np.concatenate([image, image, image], axis=2)
229
+ canny_image = Image.fromarray(image)
237
  updated_image_path = get_new_image_name(inputs, func_name="edge")
238
- image.save(updated_image_path)
231
+ canny_image.save(updated_image_path)
239
  return updated_image_path
240
 
241
  class canny2image:
242
  def __init__(self, device):
243
  print("Initialize the canny2image model.")
244
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
245
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_canny.pth', location='cpu'))
246
- self.model = model.to(device)
247
- self.device = device
248
- self.ddim_sampler = DDIMSampler(self.model)
249
- self.ddim_steps = 20
250
- self.image_resolution = 512
251
- self.num_samples = 1
252
- self.save_memory = False
253
- self.strength = 1.0
254
- self.guess_mode = False
255
- self.scale = 9.0
256
- self.seed = -1
257
- self.a_prompt = 'best quality, extremely detailed'
258
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
259
-
237
+ self.low_threshold = 100
238
+ self.high_threshold = 200
239
+
240
+ # Models: the Canny ControlNet adapter plus the SD 1.5 base pipeline
241
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
242
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
243
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
244
+ )
245
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
246
+
247
+ # Model CPU offload loads the individual components onto the GPU on demand, so we
248
+ # don't need to call pipe.to("cuda") explicitly.
249
+ self.pipe.enable_model_cpu_offload()
250
+
251
+ self.pipe.enable_xformers_memory_efficient_attention()
252
+
253
+ # Fixed generator seed for reproducible outputs
254
+ self.generator = torch.manual_seed(0)
255
+
256
+
257
+ def get_canny_filter(self, image):
258
+ if not isinstance(image, np.ndarray):
259
+ image = np.array(image)
260
+ image = cv2.Canny(image, self.low_threshold, self.high_threshold)
261
+ image = image[:, :, None]
262
+ image = np.concatenate([image, image, image], axis=2)
263
+ canny_image = Image.fromarray(image)
264
+ return canny_image
265
+
260
  def inference(self, inputs):
261
  print("===>Starting canny2image Inference")
262
  image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
263
  image = Image.open(image_path)
264
  image = np.array(image)
265
- image = 255 - image
266
  prompt = instruct_text
267
- img = resize_image(HWC3(image), self.image_resolution)
268
- H, W, C = img.shape
269
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
270
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
271
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
272
- self.seed = random.randint(0, 65535)
273
- seed_everything(self.seed)
274
- if self.save_memory:
275
- self.model.low_vram_shift(is_diffusing=False)
276
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
277
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
278
- shape = (4, H // 8, W // 8)
279
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
280
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
281
- if self.save_memory:
282
- self.model.low_vram_shift(is_diffusing=False)
283
- x_samples = self.model.decode_first_stage(samples)
284
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
272
+ canny_image = self.get_canny_filter(image)
273
+ output = self.pipe(prompt, canny_image, generator=self.generator, num_images_per_prompt=1, num_inference_steps=20)
274
+
285
updated_image_path = get_new_image_name(image_path, func_name="canny2image")
286
- real_image = Image.fromarray(x_samples[0]) # get default the index0 image
287
- real_image.save(updated_image_path)
288
- return updated_image_path
289
-
290
- class image2line:
291
- def __init__(self):
292
- print("Direct detect straight line...")
293
- self.detector = MLSDdetector()
294
- self.value_thresh = 0.1
295
- self.dis_thresh = 0.1
296
- self.resolution = 512
297
-
298
- def inference(self, inputs):
299
- print("===>Starting image2hough Inference")
300
- image = Image.open(inputs)
301
- image = np.array(image)
302
- image = HWC3(image)
303
- hough = self.detector(resize_image(image, self.resolution), self.value_thresh, self.dis_thresh)
304
- updated_image_path = get_new_image_name(inputs, func_name="line-of")
305
- hough = 255 - cv2.dilate(hough, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
306
- image = Image.fromarray(hough)
307
- image.save(updated_image_path)
308
- return updated_image_path
309
-
310
-
311
- class line2image:
312
- def __init__(self, device):
313
- print("Initialize the line2image model...")
314
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
315
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_mlsd.pth', location='cpu'))
316
- self.model = model.to(device)
317
- self.device = device
318
- self.ddim_sampler = DDIMSampler(self.model)
319
- self.ddim_steps = 20
320
- self.image_resolution = 512
321
- self.num_samples = 1
322
- self.save_memory = False
323
- self.strength = 1.0
324
- self.guess_mode = False
325
- self.scale = 9.0
326
- self.seed = -1
327
- self.a_prompt = 'best quality, extremely detailed'
328
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
329
-
330
- def inference(self, inputs):
331
- print("===>Starting line2image Inference")
332
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
333
- image = Image.open(image_path)
334
- image = np.array(image)
335
- image = 255 - image
336
- prompt = instruct_text
337
- img = resize_image(HWC3(image), self.image_resolution)
338
- H, W, C = img.shape
339
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
340
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
341
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
342
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
343
- self.seed = random.randint(0, 65535)
344
- seed_everything(self.seed)
345
- if self.save_memory:
346
- self.model.low_vram_shift(is_diffusing=False)
347
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
348
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
349
- shape = (4, H // 8, W // 8)
350
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
351
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
352
- if self.save_memory:
353
- self.model.low_vram_shift(is_diffusing=False)
354
- x_samples = self.model.decode_first_stage(samples)
355
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).\
356
- cpu().numpy().clip(0,255).astype(np.uint8)
357
- updated_image_path = get_new_image_name(image_path, func_name="line2image")
358
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
359
- real_image.save(updated_image_path)
360
- return updated_image_path
361
-
362
-
363
- class image2hed:
364
- def __init__(self):
365
- print("Direct detect soft HED boundary...")
366
- self.detector = HEDdetector()
367
- self.resolution = 512
368
-
369
- def inference(self, inputs):
370
- print("===>Starting image2hed Inference")
371
- image = Image.open(inputs)
372
- image = np.array(image)
373
- image = HWC3(image)
374
- hed = self.detector(resize_image(image, self.resolution))
375
- updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
376
- image = Image.fromarray(hed)
377
- image.save(updated_image_path)
378
- return updated_image_path
379
-
380
-
381
- class hed2image:
382
- def __init__(self, device):
383
- print("Initialize the hed2image model...")
384
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
385
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_hed.pth', location='cpu'))
386
- self.model = model.to(device)
387
- self.device = device
388
- self.ddim_sampler = DDIMSampler(self.model)
389
- self.ddim_steps = 20
390
- self.image_resolution = 512
391
- self.num_samples = 1
392
- self.save_memory = False
393
- self.strength = 1.0
394
- self.guess_mode = False
395
- self.scale = 9.0
396
- self.seed = -1
397
- self.a_prompt = 'best quality, extremely detailed'
398
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
399
-
400
- def inference(self, inputs):
401
- print("===>Starting hed2image Inference")
402
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
403
- image = Image.open(image_path)
404
- image = np.array(image)
405
- prompt = instruct_text
406
- img = resize_image(HWC3(image), self.image_resolution)
407
- H, W, C = img.shape
408
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
409
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
410
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
411
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
412
- self.seed = random.randint(0, 65535)
413
- seed_everything(self.seed)
414
- if self.save_memory:
415
- self.model.low_vram_shift(is_diffusing=False)
416
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
417
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
418
- shape = (4, H // 8, W // 8)
419
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
420
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
421
- if self.save_memory:
422
- self.model.low_vram_shift(is_diffusing=False)
423
- x_samples = self.model.decode_first_stage(samples)
424
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
425
- updated_image_path = get_new_image_name(image_path, func_name="hed2image")
426
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
427
- real_image.save(updated_image_path)
428
- return updated_image_path
429
-
430
- class image2scribble:
431
- def __init__(self):
432
- print("Direct detect scribble.")
433
- self.detector = HEDdetector()
434
- self.resolution = 512
435
-
436
- def inference(self, inputs):
437
- print("===>Starting image2scribble Inference")
438
- image = Image.open(inputs)
439
- image = np.array(image)
440
- image = HWC3(image)
441
- detected_map = self.detector(resize_image(image, self.resolution))
442
- detected_map = HWC3(detected_map)
443
- image = resize_image(image, self.resolution)
444
- H, W, C = image.shape
445
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
446
- detected_map = nms(detected_map, 127, 3.0)
447
- detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
448
- detected_map[detected_map > 4] = 255
449
- detected_map[detected_map < 255] = 0
450
- detected_map = 255 - detected_map
451
- updated_image_path = get_new_image_name(inputs, func_name="scribble")
452
- image = Image.fromarray(detected_map)
453
- image.save(updated_image_path)
454
- return updated_image_path
455
-
456
- class scribble2image:
457
- def __init__(self, device):
458
- print("Initialize the scribble2image model...")
459
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
460
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_scribble.pth', location='cpu'))
461
- self.model = model.to(device)
462
- self.device = device
463
- self.ddim_sampler = DDIMSampler(self.model)
464
- self.ddim_steps = 20
465
- self.image_resolution = 512
466
- self.num_samples = 1
467
- self.save_memory = False
468
- self.strength = 1.0
469
- self.guess_mode = False
470
- self.scale = 9.0
471
- self.seed = -1
472
- self.a_prompt = 'best quality, extremely detailed'
473
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
474
-
475
- def inference(self, inputs):
476
- print("===>Starting scribble2image Inference")
477
- print(f'sketch device {self.device}')
478
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
479
- image = Image.open(image_path)
480
- image = np.array(image)
481
- prompt = instruct_text
482
- image = 255 - image
483
- img = resize_image(HWC3(image), self.image_resolution)
484
- H, W, C = img.shape
485
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
486
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
487
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
488
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
489
- self.seed = random.randint(0, 65535)
490
- seed_everything(self.seed)
491
- if self.save_memory:
492
- self.model.low_vram_shift(is_diffusing=False)
493
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
494
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
495
- shape = (4, H // 8, W // 8)
496
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
497
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
498
- if self.save_memory:
499
- self.model.low_vram_shift(is_diffusing=False)
500
- x_samples = self.model.decode_first_stage(samples)
501
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
502
- updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
503
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
504
- real_image.save(updated_image_path)
505
- return updated_image_path
506
-
507
- class image2pose:
508
- def __init__(self):
509
- print("Direct human pose.")
510
- self.detector = OpenposeDetector()
511
- self.resolution = 512
512
-
513
- def inference(self, inputs):
514
- print("===>Starting image2pose Inference")
515
- image = Image.open(inputs)
516
- image = np.array(image)
517
- image = HWC3(image)
518
- detected_map, _ = self.detector(resize_image(image, self.resolution))
519
- detected_map = HWC3(detected_map)
520
- image = resize_image(image, self.resolution)
521
- H, W, C = image.shape
522
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
523
- updated_image_path = get_new_image_name(inputs, func_name="human-pose")
524
- image = Image.fromarray(detected_map)
525
- image.save(updated_image_path)
526
- return updated_image_path
527
-
528
- class pose2image:
529
- def __init__(self, device):
530
- print("Initialize the pose2image model...")
531
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
532
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_openpose.pth', location='cpu'))
533
- self.model = model.to(device)
534
- self.device = device
535
- self.ddim_sampler = DDIMSampler(self.model)
536
- self.ddim_steps = 20
537
- self.image_resolution = 512
538
- self.num_samples = 1
539
- self.save_memory = False
540
- self.strength = 1.0
541
- self.guess_mode = False
542
- self.scale = 9.0
543
- self.seed = -1
544
- self.a_prompt = 'best quality, extremely detailed'
545
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
546
-
547
- def inference(self, inputs):
548
- print("===>Starting pose2image Inference")
549
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
550
- image = Image.open(image_path)
551
- image = np.array(image)
552
- prompt = instruct_text
553
- img = resize_image(HWC3(image), self.image_resolution)
554
- H, W, C = img.shape
555
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
556
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
557
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
558
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
559
- self.seed = random.randint(0, 65535)
560
- seed_everything(self.seed)
561
- if self.save_memory:
562
- self.model.low_vram_shift(is_diffusing=False)
563
- cond = {"c_concat": [control], "c_crossattn": [ self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
564
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
565
- shape = (4, H // 8, W // 8)
566
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
567
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
568
- if self.save_memory:
569
- self.model.low_vram_shift(is_diffusing=False)
570
- x_samples = self.model.decode_first_stage(samples)
571
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
572
- updated_image_path = get_new_image_name(image_path, func_name="pose2image")
573
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
574
- real_image.save(updated_image_path)
575
- return updated_image_path
576
-
577
- class image2seg:
578
- def __init__(self):
579
- print("Direct segmentations.")
580
- self.detector = UniformerDetector()
581
- self.resolution = 512
582
-
583
- def inference(self, inputs):
584
- print("===>Starting image2seg Inference")
585
- image = Image.open(inputs)
586
- image = np.array(image)
587
- image = HWC3(image)
588
- detected_map = self.detector(resize_image(image, self.resolution))
589
- detected_map = HWC3(detected_map)
590
- image = resize_image(image, self.resolution)
591
- H, W, C = image.shape
592
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
593
- updated_image_path = get_new_image_name(inputs, func_name="segmentation")
594
- image = Image.fromarray(detected_map)
595
- image.save(updated_image_path)
596
- return updated_image_path
597
-
598
- class seg2image:
599
- def __init__(self, device):
600
- print("Initialize the seg2image model...")
601
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
602
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_seg.pth', location='cpu'))
603
- self.model = model.to(device)
604
- self.device = device
605
- self.ddim_sampler = DDIMSampler(self.model)
606
- self.ddim_steps = 20
607
- self.image_resolution = 512
608
- self.num_samples = 1
609
- self.save_memory = False
610
- self.strength = 1.0
611
- self.guess_mode = False
612
- self.scale = 9.0
613
- self.seed = -1
614
- self.a_prompt = 'best quality, extremely detailed'
615
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
616
-
617
- def inference(self, inputs):
618
- print("===>Starting seg2image Inference")
619
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
620
- image = Image.open(image_path)
621
- image = np.array(image)
622
- prompt = instruct_text
623
- img = resize_image(HWC3(image), self.image_resolution)
624
- H, W, C = img.shape
625
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
626
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
627
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
628
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
629
- self.seed = random.randint(0, 65535)
630
- seed_everything(self.seed)
631
- if self.save_memory:
632
- self.model.low_vram_shift(is_diffusing=False)
633
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
634
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
635
- shape = (4, H // 8, W // 8)
636
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
637
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
638
- if self.save_memory:
639
- self.model.low_vram_shift(is_diffusing=False)
640
- x_samples = self.model.decode_first_stage(samples)
641
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
642
- updated_image_path = get_new_image_name(image_path, func_name="segment2image")
643
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
644
- real_image.save(updated_image_path)
645
- return updated_image_path
646
-
647
- class image2depth:
648
- def __init__(self):
649
- print("Direct depth estimation.")
650
- self.detector = MidasDetector()
651
- self.resolution = 512
652
-
653
- def inference(self, inputs):
654
- print("===>Starting image2depth Inference")
655
- image = Image.open(inputs)
656
- image = np.array(image)
657
- image = HWC3(image)
658
- detected_map, _ = self.detector(resize_image(image, self.resolution))
659
- detected_map = HWC3(detected_map)
660
- image = resize_image(image, self.resolution)
661
- H, W, C = image.shape
662
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
663
- updated_image_path = get_new_image_name(inputs, func_name="depth")
664
- image = Image.fromarray(detected_map)
665
- image.save(updated_image_path)
666
- return updated_image_path
667
-
668
- class depth2image:
669
- def __init__(self, device):
670
- print("Initialize depth2image model...")
671
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
672
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_depth.pth', location='cpu'))
673
- self.model = model.to(device)
674
- self.device = device
675
- self.ddim_sampler = DDIMSampler(self.model)
676
- self.ddim_steps = 20
677
- self.image_resolution = 512
678
- self.num_samples = 1
679
- self.save_memory = False
680
- self.strength = 1.0
681
- self.guess_mode = False
682
- self.scale = 9.0
683
- self.seed = -1
684
- self.a_prompt = 'best quality, extremely detailed'
685
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
686
-
687
- def inference(self, inputs):
688
- print("===>Starting depth2image Inference")
689
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
690
- image = Image.open(image_path)
691
- image = np.array(image)
692
- prompt = instruct_text
693
- img = resize_image(HWC3(image), self.image_resolution)
694
- H, W, C = img.shape
695
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
696
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
697
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
698
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
699
- self.seed = random.randint(0, 65535)
700
- seed_everything(self.seed)
701
- if self.save_memory:
702
- self.model.low_vram_shift(is_diffusing=False)
703
- cond = {"c_concat": [control], "c_crossattn": [ self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
704
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
705
- shape = (4, H // 8, W // 8)
706
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
707
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
708
- if self.save_memory:
709
- self.model.low_vram_shift(is_diffusing=False)
710
- x_samples = self.model.decode_first_stage(samples)
711
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
712
- updated_image_path = get_new_image_name(image_path, func_name="depth2image")
713
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
714
- real_image.save(updated_image_path)
715
- return updated_image_path
716
-
717
- class image2normal:
718
- def __init__(self):
719
- print("Direct normal estimation.")
720
- self.detector = MidasDetector()
721
- self.resolution = 512
722
- self.bg_threshold = 0.4
723
-
724
- def inference(self, inputs):
725
- print("===>Starting image2 normal Inference")
726
- image = Image.open(inputs)
727
- image = np.array(image)
728
- image = HWC3(image)
729
- _, detected_map = self.detector(resize_image(image, self.resolution), bg_th=self.bg_threshold)
730
- detected_map = HWC3(detected_map)
731
- image = resize_image(image, self.resolution)
732
- H, W, C = image.shape
733
- detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
734
- updated_image_path = get_new_image_name(inputs, func_name="normal-map")
735
- image = Image.fromarray(detected_map)
736
- image.save(updated_image_path)
737
- return updated_image_path
738
-
739
- class normal2image:
740
- def __init__(self, device):
741
- print("Initialize normal2image model...")
742
- model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
743
- model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_normal.pth', location='cpu'))
744
- self.model = model.to(device)
745
- self.device = device
746
- self.ddim_sampler = DDIMSampler(self.model)
747
- self.ddim_steps = 20
748
- self.image_resolution = 512
749
- self.num_samples = 1
750
- self.save_memory = False
751
- self.strength = 1.0
752
- self.guess_mode = False
753
- self.scale = 9.0
754
- self.seed = -1
755
- self.a_prompt = 'best quality, extremely detailed'
756
- self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
757
-
758
- def inference(self, inputs):
759
- print("===>Starting normal2image Inference")
760
- image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
761
- image = Image.open(image_path)
762
- image = np.array(image)
763
- prompt = instruct_text
764
- img = image[:, :, ::-1].copy()
765
- img = resize_image(HWC3(img), self.image_resolution)
766
- H, W, C = img.shape
767
- img = cv2.resize(img, (W, H), interpolation=cv2.INTER_NEAREST)
768
- control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
769
- control = torch.stack([control for _ in range(self.num_samples)], dim=0)
770
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
771
- self.seed = random.randint(0, 65535)
772
- seed_everything(self.seed)
773
- if self.save_memory:
774
- self.model.low_vram_shift(is_diffusing=False)
775
- cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
776
- un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
777
- shape = (4, H // 8, W // 8)
778
- self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)
779
- samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
780
- if self.save_memory:
781
- self.model.low_vram_shift(is_diffusing=False)
782
- x_samples = self.model.decode_first_stage(samples)
783
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
784
- updated_image_path = get_new_image_name(image_path, func_name="normal2image")
785
- real_image = Image.fromarray(x_samples[0]) # default the index0 image
276
+ real_image = output.images[0] # the first generated image (the pipeline already returns PIL images)
786
real_image.save(updated_image_path)
787
  return updated_image_path
788
 
@@ -961,4 +452,4 @@ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
961
  clear.click(bot.memory.clear)
962
  clear.click(lambda: [], None, chatbot)
963
  clear.click(lambda: [], None, state)
964
- demo.launch()
455
+ demo.launch()
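
For reference, below is a minimal standalone sketch of the Canny-conditioned generation flow this commit switches to. The model IDs ("lllyasviel/sd-controlnet-canny", "runwayml/stable-diffusion-v1-5"), thresholds, and scheduler match the diff above; the input/output file names and the prompt are hypothetical, and enable_model_cpu_offload() assumes the accelerate package and a CUDA GPU are available. It is an illustration of the same pipeline setup, not the app's exact code.

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

# 1) Build a 3-channel Canny edge map, mirroring get_canny_filter above.
image = np.array(Image.open("input.png"))  # hypothetical input file
edges = cv2.Canny(image, 100, 200)         # same thresholds as the app's defaults
canny_image = Image.fromarray(np.concatenate([edges[:, :, None]] * 3, axis=2))

# 2) SD 1.5 base model plus the Canny ControlNet adapter, as in canny2image.__init__.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()  # streams components to the GPU on demand instead of pipe.to("cuda")

# 3) Generate; the pipeline returns PIL images directly.
generator = torch.manual_seed(0)
result = pipe("a photo of a house", image=canny_image,  # hypothetical prompt
              num_inference_steps=20, generator=generator).images[0]
result.save("output.png")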