Alexander McKinney committed on
Commit
20ddfe8
·
1 Parent(s): 92ba1f6

adds stable diffusion 2, attention slicing, cuda masking

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -6,16 +6,24 @@ import os
6
  from PIL import Image
7
  from typing import List, Optional
8
  from functools import reduce
 
9
 
10
  import gradio as gr
11
 
12
  from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
13
  from transformers.models.detr.feature_extraction_detr import rgb_to_id
14
 
15
- from diffusers import StableDiffusionInpaintPipeline
 
 
 
 
 
 
 
16
 
17
  auth_token = os.environ.get("READ_TOKEN")
18
- try_cuda = True
19
 
20
  torch.inference_mode()
21
  torch.no_grad()
@@ -29,7 +37,7 @@ def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic
29
  return feature_extractor, model, cfg
30
 
31
  # Load diffusion pipeline
32
- def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
33
  return StableDiffusionInpaintPipeline.from_pretrained(
34
  model_name,
35
  revision='fp16',
@@ -51,10 +59,10 @@ def max_pool(x: torch.Tensor, kernel_size: int):
51
 
52
  # Apply min-max pooling to clean up mask
53
  def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
54
- mask = torch.Tensor(mask[None, None]).float()
55
  mask = min_pool(mask, min_kernel)
56
  mask = max_pool(mask, max_kernel)
57
- mask = mask.bool().squeeze().numpy()
58
  return mask
59
 
60
 
@@ -62,11 +70,14 @@ feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_mode
62
  pipe = load_diffusion_pipeline()
63
 
64
  device = get_device(try_cuda=try_cuda)
 
65
  pipe = pipe.to(device)
 
 
66
 
67
  # Callback function that runs segmentation and updates CheckboxGroup
68
  def fn_segmentation(image, max_kernel, min_kernel):
69
- inputs = feature_extractor(images=image, return_tensors="pt")
70
  outputs = segmentation_model(**inputs)
71
 
72
  processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
 
6
  from PIL import Image
7
  from typing import List, Optional
8
  from functools import reduce
9
+ from argparse import ArgumentParser
10
 
11
  import gradio as gr
12
 
13
  from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
14
  from transformers.models.detr.feature_extraction_detr import rgb_to_id
15
 
16
+ from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
17
+
18
+ # TODO: xformers install for faster diffusion
19
+
20
+ parser = ArgumentParser()
21
+ parser.add_argument('--disable-cuda', action='store_true')
22
+ parser.add_argument('--attention-slicing', action='store_true')
23
+ args = parser.parse_args()
24
 
25
  auth_token = os.environ.get("READ_TOKEN")
26
+ try_cuda = not args.disable_cuda
27
 
28
  torch.inference_mode()
29
  torch.no_grad()
 
37
  return feature_extractor, model, cfg
38
 
39
  # Load diffusion pipeline
40
+ def load_diffusion_pipeline(model_name: str = 'stabilityai/stable-diffusion-2-inpainting'):
41
  return StableDiffusionInpaintPipeline.from_pretrained(
42
  model_name,
43
  revision='fp16',
 
59
 
60
  # Apply min-max pooling to clean up mask
61
  def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
62
+ mask = torch.Tensor(mask[None, None]).float().to(device)
63
  mask = min_pool(mask, min_kernel)
64
  mask = max_pool(mask, max_kernel)
65
+ mask = mask.bool().squeeze().cpu().numpy()
66
  return mask
67
 
68
 
 
70
  pipe = load_diffusion_pipeline()
71
 
72
  device = get_device(try_cuda=try_cuda)
73
+ segmentation_model = segmentation_model.to(device)
74
  pipe = pipe.to(device)
75
+ if args.attention_slicing:
76
+ pipe.enable_attention_slicing()
77
 
78
  # Callback function that runs segmentation and updates CheckboxGroup
79
  def fn_segmentation(image, max_kernel, min_kernel):
80
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
81
  outputs = segmentation_model(**inputs)
82
 
83
  processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)