nuwandaa commited on
Commit
18bffde
·
1 Parent(s): 39a774c
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -106,18 +106,19 @@ def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, s
106
  AAS_end_layer=END_LAYER, # AAS end layer
107
  num_inference_steps=num_inference_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
108
  generator=generator,
109
- guidance_scale=1
 
110
  ).images[0]
111
  print('Inferece: DONE.')
112
 
113
- # pil_mask = to_pil_image(mask.squeeze(0))
114
- # pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
115
- # mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
116
- # mask_f = 1-(1 - mask) * (1 - mask_blurred)
117
 
118
- # image_1 = image.unsqueeze(0)
119
 
120
- return source_image, image
121
  except:
122
  print(traceback.format_exc())
123
 
@@ -165,11 +166,7 @@ with gr.Blocks() as demo:
165
  with gr.Column():
166
  run_button = gr.Button("Generate")
167
 
168
- result = ImageSlider(
169
- interactive=False,
170
- label="Generated Image",
171
- type="pil"
172
- )
173
 
174
  run_button.click(
175
  fn=remove,
 
106
  AAS_end_layer=END_LAYER, # AAS end layer
107
  num_inference_steps=num_inference_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
108
  generator=generator,
109
+ guidance_scale=1,
110
+ output_type='pt'
111
  ).images[0]
112
  print('Inferece: DONE.')
113
 
114
+ pil_mask = to_pil_image(mask.squeeze(0))
115
+ pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
116
+ mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
117
+ mask_f = 1-(1 - mask) * (1 - mask_blurred)
118
 
119
+ image_1 = image.unsqueeze(0)
120
 
121
+ return source_image_pure, pil_mask, image_1
122
  except:
123
  print(traceback.format_exc())
124
 
 
166
  with gr.Column():
167
  run_button = gr.Button("Generate")
168
 
169
+ result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
 
 
 
 
170
 
171
  run_button.click(
172
  fn=remove,