nuwandaa commited on
Commit
19a16d0
·
1 Parent(s): 18bffde

Update flow

Browse files
Files changed (1) hide show
  1. app.py +37 -25
app.py CHANGED
@@ -8,10 +8,9 @@ from diffusers.utils import load_image
8
  from torchvision.transforms.functional import to_tensor, gaussian_blur
9
  from matplotlib import pyplot as plt
10
  import gradio as gr
11
- import spaces
12
  from gradio_imageslider import ImageSlider
13
  from torchvision.transforms.functional import to_pil_image, to_tensor
14
- from PIL import ImageFilter
15
  import traceback
16
 
17
 
@@ -22,19 +21,34 @@ def preprocess_image(input_image, device):
22
  image = image.expand(-1, 3, -1, -1)
23
  image = F.interpolate(image, (1024, 1024))
24
  image = image.to(dtype).to(device)
25
-
26
  return image
27
 
28
 
29
  def preprocess_mask(input_mask, device):
30
- mask = to_tensor(input_mask.convert('L'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  mask = mask.unsqueeze_(0).float() # 0 or 1
32
  mask = F.interpolate(mask, (1024, 1024))
33
  mask = gaussian_blur(mask, kernel_size=(77, 77))
34
  mask[mask < 0.1] = 0
35
  mask[mask >= 0.1] = 1
36
  mask = mask.to(dtype).to(device)
37
-
38
  return mask
39
 
40
 
@@ -42,7 +56,7 @@ def make_redder(img, mask, increase_factor=0.4):
42
  img_redder = img.clone()
43
  mask_expanded = mask.expand_as(img)
44
  img_redder[0][mask_expanded[0] == 1] = torch.clamp(img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1)
45
-
46
  return img_redder
47
 
48
 
@@ -67,29 +81,28 @@ pipeline = DiffusionPipeline.from_pretrained(
67
 
68
  if is_attention_slicing_enabled:
69
  pipeline.enable_attention_slicing()
70
-
71
  if is_cpu_offload_enabled:
72
  pipeline.enable_model_cpu_offload()
73
 
74
 
75
- @spaces.GPU
76
  def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
77
  try:
78
- generator = torch.Generator(device).manual_seed(seed)
79
  prompt = "" # Set prompt to null
80
 
81
  source_image_pure = gradio_image["background"]
82
  mask_image_pure = gradio_image["layers"][0]
83
  source_image = preprocess_image(source_image_pure.convert('RGB'), device)
84
  mask = preprocess_mask(mask_image_pure, device)
85
-
86
  START_STEP = 0 # AAS start step
87
  END_STEP = int(strength * num_inference_steps) # AAS end step
88
- LAYER = 34 # 0~23down,24~33mid,34~69up /AAS start layer
89
  END_LAYER = 70 # AAS end layer
90
  ss_steps = 9 # similarity suppression steps
91
  ss_scale = 0.3 # similarity suppression scale
92
-
93
  image = pipeline(
94
  prompt=prompt,
95
  image=source_image,
@@ -102,26 +115,25 @@ def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, s
102
  ss_steps = ss_steps, # similarity suppression steps
103
  ss_scale = ss_scale, # similarity suppression scale
104
  AAS_start_step=START_STEP, # AAS start step
105
- AAS_start_layer=LAYER, # AAS start layer
106
  AAS_end_layer=END_LAYER, # AAS end layer
107
  num_inference_steps=num_inference_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
108
  generator=generator,
109
- guidance_scale=1,
110
- output_type='pt'
111
  ).images[0]
112
  print('Inferece: DONE.')
113
-
114
  pil_mask = to_pil_image(mask.squeeze(0))
115
  pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
116
  mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
117
  mask_f = 1-(1 - mask) * (1 - mask_blurred)
118
-
119
- image_1 = image.unsqueeze(0)
120
-
121
- return source_image_pure, pil_mask, image_1
122
  except:
123
  print(traceback.format_exc())
124
-
125
 
126
  title = """<h1 align="center">Object Remove</h1>"""
127
  with gr.Blocks() as demo:
@@ -157,7 +169,7 @@ with gr.Blocks() as demo:
157
  step=0.1,
158
  label="Strength"
159
  )
160
-
161
  input_image = gr.ImageMask(
162
  type="pil", label="Input Image",crop_size=(1200,1200), layers=False
163
  )
@@ -167,11 +179,11 @@ with gr.Blocks() as demo:
167
  run_button = gr.Button("Generate")
168
 
169
  result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
170
-
171
  run_button.click(
172
  fn=remove,
173
  inputs=[input_image, guidance_scale, num_steps, seed, strength],
174
  outputs=result,
175
  )
176
-
177
- demo.queue(max_size=12).launch(share=False)
 
8
  from torchvision.transforms.functional import to_tensor, gaussian_blur
9
  from matplotlib import pyplot as plt
10
  import gradio as gr
 
11
  from gradio_imageslider import ImageSlider
12
  from torchvision.transforms.functional import to_pil_image, to_tensor
13
+ from PIL import ImageFilter, Image
14
  import traceback
15
 
16
 
 
21
  image = image.expand(-1, 3, -1, -1)
22
  image = F.interpolate(image, (1024, 1024))
23
  image = image.to(dtype).to(device)
24
+
25
  return image
26
 
27
 
28
  def preprocess_mask(input_mask, device):
29
+ # Split the channels
30
+ r, g, b, alpha = input_mask.split()
31
+
32
+ # Create a new image where:
33
+ # - Black areas (where RGB = 0) become white (255).
34
+ # - Transparent areas (where alpha = 0) become black (0).
35
+ new_mask = Image.new("L", input_mask.size)
36
+
37
+ for x in range(input_mask.width):
38
+ for y in range(input_mask.height):
39
+ if alpha.getpixel((x, y)) == 0: # Transparent pixel
40
+ new_mask.putpixel((x, y), 0) # Set to black
41
+ else: # Non-transparent pixel (originally black in the mask)
42
+ new_mask.putpixel((x, y), 255) # Set to white
43
+
44
+ mask = to_tensor(new_mask.convert('L'))
45
  mask = mask.unsqueeze_(0).float() # 0 or 1
46
  mask = F.interpolate(mask, (1024, 1024))
47
  mask = gaussian_blur(mask, kernel_size=(77, 77))
48
  mask[mask < 0.1] = 0
49
  mask[mask >= 0.1] = 1
50
  mask = mask.to(dtype).to(device)
51
+
52
  return mask
53
 
54
 
 
56
  img_redder = img.clone()
57
  mask_expanded = mask.expand_as(img)
58
  img_redder[0][mask_expanded[0] == 1] = torch.clamp(img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1)
59
+
60
  return img_redder
61
 
62
 
 
81
 
82
  if is_attention_slicing_enabled:
83
  pipeline.enable_attention_slicing()
84
+
85
  if is_cpu_offload_enabled:
86
  pipeline.enable_model_cpu_offload()
87
 
88
 
 
89
  def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
90
  try:
91
+ generator = torch.Generator('cuda').manual_seed(seed)
92
  prompt = "" # Set prompt to null
93
 
94
  source_image_pure = gradio_image["background"]
95
  mask_image_pure = gradio_image["layers"][0]
96
  source_image = preprocess_image(source_image_pure.convert('RGB'), device)
97
  mask = preprocess_mask(mask_image_pure, device)
98
+
99
  START_STEP = 0 # AAS start step
100
  END_STEP = int(strength * num_inference_steps) # AAS end step
101
+ LAYER = 34 # 0~23down,24~33mid,34~69up /AAS start layer
102
  END_LAYER = 70 # AAS end layer
103
  ss_steps = 9 # similarity suppression steps
104
  ss_scale = 0.3 # similarity suppression scale
105
+
106
  image = pipeline(
107
  prompt=prompt,
108
  image=source_image,
 
115
  ss_steps = ss_steps, # similarity suppression steps
116
  ss_scale = ss_scale, # similarity suppression scale
117
  AAS_start_step=START_STEP, # AAS start step
118
+ AAS_start_layer=LAYER, # AAS start layer
119
  AAS_end_layer=END_LAYER, # AAS end layer
120
  num_inference_steps=num_inference_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
121
  generator=generator,
122
+ guidance_scale=1
 
123
  ).images[0]
124
  print('Inferece: DONE.')
125
+
126
  pil_mask = to_pil_image(mask.squeeze(0))
127
  pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
128
  mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
129
  mask_f = 1-(1 - mask) * (1 - mask_blurred)
130
+
131
+ # image_1 = image.unsqueeze(0)
132
+
133
+ return source_image_pure, pil_mask, image
134
  except:
135
  print(traceback.format_exc())
136
+
137
 
138
  title = """<h1 align="center">Object Remove</h1>"""
139
  with gr.Blocks() as demo:
 
169
  step=0.1,
170
  label="Strength"
171
  )
172
+
173
  input_image = gr.ImageMask(
174
  type="pil", label="Input Image",crop_size=(1200,1200), layers=False
175
  )
 
179
  run_button = gr.Button("Generate")
180
 
181
  result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
182
+
183
  run_button.click(
184
  fn=remove,
185
  inputs=[input_image, guidance_scale, num_steps, seed, strength],
186
  outputs=result,
187
  )
188
+
189
+ demo.queue(max_size=12).launch(share=True)