chaojiemao committed
Commit dd72440 · 2 parents: d1a539d, f4d0287

Merge branch 'main' of https://huggingface.co/spaces/scepter-studio/ACE-Plus into main

Files changed (2):
  1. app.py +195 -184
  2. inference/__init__.py +1 -1
app.py CHANGED
@@ -13,19 +13,15 @@ import shlex
 import subprocess
 subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
 subprocess.run(shlex.split('pip install scepter'))
+subprocess.run(shlex.split('pip install numpy==1.26'))
 from scepter.modules.transform.io import pillow_convert
 from scepter.modules.utils.config import Config
 from scepter.modules.utils.distribute import we
 from scepter.modules.utils.file_system import FS
-
-from inference.ace_plus_diffusers import ACEPlusDiffuserInference
+from examples.examples import fft_examples
+from inference.registry import INFERENCES
 from inference.utils import edit_preprocess
-from examples.examples import all_examples
-
 
-inference_dict = {
-    "ACE_DIFFUSER_PLUS": ACEPlusDiffuserInference
-}
 
 fs_list = [
     Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
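The import changes above replace the hard-coded `inference_dict` mapping with a registry (`INFERENCES`) that builds the pipeline straight from its config. For readers unfamiliar with the pattern, here is a minimal sketch of how such a registry typically works; the real one lives in `inference/registry.py`, and the registration API and lookup key shown here are assumptions for illustration:

```python
# Minimal registry sketch (illustrative; not the actual inference/registry.py).
class Registry:
    def __init__(self):
        self._classes = {}

    def register_class(self, cls):
        # Store the class under its own name so configs can refer to it.
        self._classes[cls.__name__] = cls
        return cls

    def build(self, cfg):
        # Instantiate the class named in the config and hand it the full config.
        return self._classes[cfg.NAME](cfg)  # NAME as lookup key is an assumption

INFERENCES = Registry()

@INFERENCES.register_class
class ACEInference:
    def __init__(self, cfg):
        self.cfg = cfg
```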
@@ -38,15 +34,10 @@ for one_fs in fs_list:
     FS.init_fs_client(one_fs)
 
 os.environ["FLUX_FILL_PATH"]="hf://black-forest-labs/FLUX.1-Fill-dev"
-os.environ["PORTRAIT_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors"
-os.environ["SUBJECT_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@subject/comfyui_subject_lora16.safetensors"
-os.environ["LOCAL_MODEL_PATH"]="hf://ali-vilab/ACE_Plus@local_editing/comfyui_local_lora16.safetensors"
+os.environ["ACE_PLUS_FFT_MODEL"]="hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors"
 
 FS.get_dir_to_local_dir(os.environ["FLUX_FILL_PATH"])
-FS.get_from(os.environ["PORTRAIT_MODEL_PATH"])
-FS.get_from(os.environ["SUBJECT_MODEL_PATH"])
-FS.get_from(os.environ["LOCAL_MODEL_PATH"])
-
+FS.get_from(os.environ["ACE_PLUS_FFT_MODEL"])
 
 csv.field_size_limit(sys.maxsize)
 refresh_sty = '\U0001f504'  # 🔄
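The three task-specific LoRA checkpoints collapse into a single FFT checkpoint here. The `hf://<repo>@<file>` URIs are resolved by scepter's `HuggingfaceFs` client; below is a hedged sketch of the equivalent resolution with `huggingface_hub`, assuming the `@` separates repo id from filename as the paths above suggest:

```python
from huggingface_hub import hf_hub_download

def fetch_hf_uri(uri: str) -> str:
    # Resolve "hf://<repo_id>@<filename>" to a local cache path.
    # (Scepter's FS client does this internally; this helper is illustrative
    # and ignores the whole-repo form used for FLUX_FILL_PATH.)
    assert uri.startswith("hf://")
    repo_id, _, filename = uri[len("hf://"):].partition("@")
    return hf_hub_download(repo_id=repo_id, filename=filename)

local_path = fetch_hf_uri("hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors")
```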
@@ -60,51 +51,39 @@ lock = threading.Lock()
 class DemoUI(object):
     #@spaces.GPU(duration=60)
     def __init__(self,
-                 infer_dir = "./config",
-                 model_list='./models/model_zoo.yaml'
+                 infer_dir="./config/ace_plus_fft.yaml"
                  ):
-        self.model_yamls = glob.glob(os.path.join(infer_dir,
-                                                  '*.yaml'))
+        self.model_yamls = [infer_dir]
         self.model_choices = dict()
         self.default_model_name = ''
+        self.edit_type_dict = {}
+        self.edit_type_list = []
+        self.default_type_list = []
         for i in self.model_yamls:
             model_cfg = Config(load=True, cfg_file=i)
-            model_name = model_cfg.NAME
+            model_name = model_cfg.VERSION
             if model_cfg.IS_DEFAULT: self.default_model_name = model_name
             self.model_choices[model_name] = model_cfg
+            for preprocessor in model_cfg.get("PREPROCESSOR", []):
+                if preprocessor["TYPE"] in self.edit_type_dict:
+                    continue
+                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
+                self.default_type_list.append(preprocessor["TYPE"])
         print('Models: ', self.model_choices.keys())
         assert len(self.model_choices) > 0
         if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
         self.model_name = self.default_model_name
         pipe_cfg = self.model_choices[self.default_model_name]
-        infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
-        self.pipe = inference_dict[infer_name]()
-        self.pipe.init_from_cfg(pipe_cfg)
-
-        # choose different model
-        self.task_model_cfg = Config(load=True, cfg_file=model_list)
-        self.task_model = {}
-        self.task_model_list = []
-        self.edit_type_dict = {"repainting": None}
-        self.edit_type_list = ["repainting"]
-        for task_name, task_model in self.task_model_cfg.MODEL.items():
-            self.task_model[task_name.lower()] = task_model
-            self.task_model_list.append(task_name.lower())
-            for preprocessor in task_model.get("PREPROCESSOR", []):
-                if preprocessor["TYPE"] in self.edit_type_dict:
-                    continue
-                preprocessor["REPAINTING_SCALE"] = task_model.get("REPAINTING_SCALE", 1.0)
-                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
-        self.max_msgs = 20
+        self.pipe = INFERENCES.build(pipe_cfg)
         # reformat examples
         self.all_examples = [
             [
-                one_example["task_type"], one_example["edit_type"], one_example["instruction"],
-                one_example["input_reference_image"], one_example["input_image"],
-                one_example["input_mask"], one_example["output_h"],
-                one_example["output_w"], one_example["seed"]
-            ]
-            for one_example in all_examples
+                one_example["edit_type"], one_example["instruction"],
+                one_example["input_reference_image"], one_example["input_image"],
+                one_example["input_mask"], one_example["output_h"],
+                one_example["output_w"], one_example["seed"]
+            ]
+            for one_example in fft_examples
         ]
 
     def construct_edit_image(self, edit_image, edit_mask):
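`__init__` now reads everything from a single YAML: the dropdown label comes from `VERSION`, and the edit types come from the config's `PREPROCESSOR` list instead of a per-task model zoo. A hedged sketch of the config shape this code implies, written as a Python dict — only `VERSION`, `IS_DEFAULT`, `PREPROCESSOR`, `TYPE`, `REPAINTING_SCALE`, and `ANNOTATOR` are attested in the diff; the remaining keys and values are assumptions:

```python
# Implied shape of config/ace_plus_fft.yaml (illustrative values).
example_cfg = {
    "NAME": "ACEInference",        # assumed registry key for INFERENCES.build
    "VERSION": "ace_plus_fft",     # shown in the Model Version dropdown
    "IS_DEFAULT": True,
    "PREPROCESSOR": [
        # One entry per edit type; ANNOTATOR configures edit_preprocess.
        {"TYPE": "repainting", "REPAINTING_SCALE": 1.0, "ANNOTATOR": None},
        {"TYPE": "contour_repainting", "REPAINTING_SCALE": 0.0,
         "ANNOTATOR": {"NAME": "CannyAnnotator"}},  # hypothetical annotator
    ],
}
```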
@@ -127,9 +106,6 @@
         else:
             return None
 
-
-
-
     def create_ui(self):
         with gr.Row(equal_height=True, visible=True):
             with gr.Column(scale=2):
@@ -146,40 +122,102 @@
                             height=600,
                             interactive=False,
                             type='pil',
-                            elem_id='preprocess_image'
+                            elem_id='preprocess_image',
+                            label='edit image'
                         )
 
                         self.edit_preprocess_mask_preview = gr.Image(
                             height=600,
                             interactive=False,
                             type='pil',
-                            elem_id='preprocess_image_mask'
+                            elem_id='preprocess_image_mask',
+                            label='edit mask'
+                        )
+
+                        self.change_preprocess_preview = gr.Image(
+                            height=600,
+                            interactive=False,
+                            type='pil',
+                            elem_id='preprocess_change_image',
+                            label='change image'
                         )
                 with gr.Row():
                     instruction = """
                     **Instruction**:
-                    1. Please choose the Task Type based on the scenario of the generation task. We provide three types of generation capabilities: Portrait ID Preservation Generation(portrait),
-                    Object ID Preservation Generation(subject), and Local Controlled Generation(local editing), which can be selected from the task dropdown menu.
-                    2. When uploading images in the Reference Image section, the generated image will reference the ID information of that image. Please ensure that the ID information is clear.
-                    In the Edit Image section, the uploaded image will maintain its structural and content information, and you must draw a mask area to specify the region to be regenerated.
-                    3. When the task type is local editing, there are various editing types to choose from. Users can select different information preserving dimensions, such as edge information,
-                    color information, and more. The pre-processing information can be viewed in the 'related input image' tab.
-                    4. More details can be found in [page](https://ali-vilab.github.io/ACE_plus_page).
+                    Users can perform reference generation or editing tasks by uploading reference images
+                    and editing images. When uploading the editing image, various editing types are available
+                    for selection. Users can choose different dimensions of information preservation,
+                    such as edge information, color information, and more. Pre-processing information
+                    can be viewed in the 'related input image' tab.
                     """
                 self.instruction = gr.Markdown(value=instruction)
+                with gr.Row():
+                    self.icon = gr.Image(
+                        value=None,
+                        interactive=False,
+                        height=150,
+                        type='pil',
+                        elem_id='icon',
+                        label='icon'
+                    )
                 with gr.Row():
                     self.model_name_dd = gr.Dropdown(
                         choices=self.model_choices,
                         value=self.default_model_name,
                         label='Model Version')
-                    self.task_type = gr.Dropdown(choices=self.task_model_list,
+                    self.edit_type = gr.Dropdown(choices=self.default_type_list,
                                                  interactive=True,
-                                                 value=self.task_model_list[0],
-                                                 label='Task Type')
-                    self.edit_type = gr.Dropdown(choices=self.edit_type_list,
-                                                 interactive=True,
-                                                 value=self.edit_type_list[0],
+                                                 value=self.default_type_list[0],
                                                  label='Edit Type')
+                with gr.Row():
+                    self.step = gr.Slider(minimum=1,
+                                          maximum=1000,
+                                          value=self.pipe.input.get("sample_steps", 20),
+                                          visible=self.pipe.input.get("sample_steps", None) is not None,
+                                          label='Sample Step')
+                    self.cfg_scale = gr.Slider(
+                        minimum=1.0,
+                        maximum=100.0,
+                        value=self.pipe.input.get("guide_scale", 4.5),
+                        visible=self.pipe.input.get("guide_scale", None) is not None,
+                        label='Guidance Scale')
+                    self.seed = gr.Slider(minimum=-1,
+                                          maximum=1000000000000,
+                                          value=-1,
+                                          label='Seed')
+                    self.output_height = gr.Slider(
+                        minimum=256,
+                        maximum=1440,
+                        value=self.pipe.input.get("image_size", [1024, 1024])[0],
+                        visible=self.pipe.input.get("image_size", None) is not None,
+                        label='Output Height')
+                    self.output_width = gr.Slider(
+                        minimum=256,
+                        maximum=1440,
+                        value=self.pipe.input.get("image_size", [1024, 1024])[1],
+                        visible=self.pipe.input.get("image_size", None) is not None,
+                        label='Output Width')
+
+                    self.repainting_scale = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=self.pipe.input.get("repainting_scale", 1.0),
+                        visible=True,
+                        label='Repainting Scale')
+                    self.use_change = gr.Checkbox(
+                        value=self.pipe.input.get("use_change", True),
+                        visible=True,
+                        label='Use Change')
+                    self.keep_pixel = gr.Checkbox(
+                        value=self.pipe.input.get("keep_pixel", True),
+                        visible=True,
+                        label='Keep Pixels')
+                    self.keep_pixels_rate = gr.Slider(
+                        minimum=0.5,
+                        maximum=1.0,
+                        value=0.8,
+                        visible=True,
+                        label='keep_pixel rate')
                 with gr.Row():
                     self.generation_info_preview = gr.Markdown(
                         label='System Log.',
@@ -192,11 +230,11 @@
                         placeholder='Input "@" find history of image',
                         label='Instruction',
                         container=False,
-                        lines = 1)
+                        lines=1)
                 with gr.Column(scale=2, min_width=100):
                     with gr.Row():
                         with gr.Column(scale=1, min_width=100):
-                            self.chat_btn = gr.Button(value='Generate', variant = "primary")
+                            self.chat_btn = gr.Button(value='Generate', variant="primary")
 
         with gr.Accordion(label='Advance', open=True):
             with gr.Row(visible=True):
@@ -223,45 +261,8 @@
                     format="png"
                 )
 
-            with gr.Row():
-                self.step = gr.Slider(minimum=1,
-                                      maximum=1000,
-                                      value=self.pipe.input.get("sample_steps", 20),
-                                      visible=self.pipe.input.get("sample_steps", None) is not None,
-                                      label='Sample Step')
-                self.cfg_scale = gr.Slider(
-                    minimum=1.0,
-                    maximum=100.0,
-                    value=self.pipe.input.get("guide_scale", 4.5),
-                    visible=self.pipe.input.get("guide_scale", None) is not None,
-                    label='Guidance Scale')
-                self.seed = gr.Slider(minimum=-1,
-                                      maximum=10000000,
-                                      value=-1,
-                                      label='Seed')
-                self.output_height = gr.Slider(
-                    minimum=256,
-                    maximum=1440,
-                    value=self.pipe.input.get("output_height", 1024),
-                    visible=self.pipe.input.get("output_height", None) is not None,
-                    label='Output Height')
-                self.output_width = gr.Slider(
-                    minimum=256,
-                    maximum=1440,
-                    value=self.pipe.input.get("output_width", 1024),
-                    visible=self.pipe.input.get("output_width", None) is not None,
-                    label='Output Width')
-
-                self.repainting_scale = gr.Slider(
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=self.pipe.input.get("repainting_scale", 1.0),
-                    visible=True,
-                    label='Repainting Scale')
-            with gr.Row():
-                self.eg = gr.Column(visible=True)
-
-
+        with gr.Row():
+            self.eg = gr.Column(visible=True)
 
     def set_callbacks(self, *args, **kwargs):
         ########################################
@@ -276,25 +277,23 @@
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
             pipe_cfg = self.model_choices[model_name]
-            infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
-            self.pipe = inference_dict[infer_name]()
-            self.pipe.init_from_cfg(pipe_cfg)
+            self.pipe = INFERENCES.build(pipe_cfg)
             self.model_name = model_name
             lock.release()
 
             return (model_name, gr.update(),
                     gr.Slider(
-                        value=self.pipe.input.get("sample_steps", 20),
-                        visible=self.pipe.input.get("sample_steps", None) is not None),
+                        value=self.pipe.input.get("sample_steps", 20),
+                        visible=self.pipe.input.get("sample_steps", None) is not None),
                     gr.Slider(
                         value=self.pipe.input.get("guide_scale", 4.5),
                         visible=self.pipe.input.get("guide_scale", None) is not None),
                     gr.Slider(
-                        value=self.pipe.input.get("output_height", 1024),
-                        visible=self.pipe.input.get("output_height", None) is not None),
+                        value=self.pipe.input.get("image_size", [1024, 1024])[0],
+                        visible=self.pipe.input.get("image_size", None) is not None),
                     gr.Slider(
-                        value=self.pipe.input.get("output_width", 1024),
-                        visible=self.pipe.input.get("output_width", None) is not None),
+                        value=self.pipe.input.get("image_size", [1024, 1024])[1],
+                        visible=self.pipe.input.get("image_size", None) is not None),
                     gr.Slider(value=self.pipe.input.get("repainting_scale", 1.0))
                     )
 
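The model-switch callback above serializes rebuilds with a lock, frees CUDA memory, and then refreshes the sliders by returning fresh component instances. A minimal, self-contained sketch of that Gradio update pattern (the dropdown choices and step lookup are invented for illustration):

```python
import gradio as gr

def on_model_change(name):
    # Returning a new gr.Slider(...) from a callback overwrites the matching
    # output component's properties (value, visibility) without rebuilding the UI.
    steps = {"ace_plus_fft": 28}.get(name, 20)  # illustrative lookup
    return gr.Slider(value=steps, visible=True)

with gr.Blocks() as demo:
    model = gr.Dropdown(choices=["ace_plus_fft", "base"], value="ace_plus_fft")
    step = gr.Slider(minimum=1, maximum=1000, value=20)
    model.change(on_model_change, inputs=[model], outputs=[step])
```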
@@ -309,31 +308,21 @@
                                           self.output_width,
                                           self.repainting_scale])
 
-        def change_task_type(task_type):
-            task_info = self.task_model[task_type]
-            edit_type_list = [self.edit_type_list[0]]
-            for preprocessor in task_info.get("PREPROCESSOR", []):
-                preprocessor["REPAINTING_SCALE"] = task_info.get("REPAINTING_SCALE", 1.0)
-                self.edit_type_dict[preprocessor["TYPE"]] = preprocessor
-                edit_type_list.append(preprocessor["TYPE"])
-
-            return gr.update(choices=edit_type_list, value=edit_type_list[0])
-
-        self.task_type.change(change_task_type, inputs=[self.task_type], outputs=[self.edit_type])
-
         def change_edit_type(edit_type):
             edit_info = self.edit_type_dict[edit_type]
             edit_info = edit_info or {}
             repainting_scale = edit_info.get("REPAINTING_SCALE", 1.0)
-            if edit_type == self.edit_type_list[0]:
-                return gr.Slider(value=1.0)
-            else:
-                return gr.Slider(
-                    value=repainting_scale)
+            return gr.Slider(value=repainting_scale)
 
         self.edit_type.change(change_edit_type, inputs=[self.edit_type], outputs=[self.repainting_scale])
 
-        def preprocess_input(ref_image, edit_image_dict, preprocess = None):
+        def resize_image(image, h):
+            ow, oh = image.size
+            w = int(h * ow / oh)
+            image = image.resize((w, h), Image.LANCZOS)
+            return image
+
+        def preprocess_input(ref_image, edit_image_dict, preprocess=None):
             err_msg = ""
             is_suc = True
             if ref_image is not None:
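The new `resize_image` helper scales an image to a fixed height while preserving its aspect ratio; it feeds the 150 px `icon` preview added to the UI. A quick worked example with the same arithmetic:

```python
from PIL import Image

img = Image.new("RGB", (1024, 768))
ow, oh = img.size            # 1024, 768
h = 150
w = int(h * ow / oh)         # int(150 * 1024 / 768) = 200
icon = img.resize((w, h), Image.LANCZOS)
assert icon.size == (200, 150)
```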
@@ -349,8 +338,9 @@
                 edit_image = None
                 edit_mask = None
             elif np.sum(np.array(edit_mask)) < 1:
-                err_msg = "You must draw the repainting area for the edited image."
-                return None, None, None, False, err_msg
+                edit_image = pillow_convert(edit_image, "RGB")
+                w, h = edit_image.size
+                edit_mask = Image.new("L", (w, h), 255)
             else:
                 edit_image = pillow_convert(edit_image, "RGB")
                 edit_mask = Image.fromarray(edit_mask).convert('L')
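Behavioral change worth noting: an untouched mask used to abort with "You must draw the repainting area for the edited image."; it now falls back to a full-white mask, i.e. the whole image is treated as repaintable. A small demonstration of the fallback branch:

```python
import numpy as np
from PIL import Image

edit_mask = np.zeros((64, 64), dtype=np.uint8)   # user drew nothing
if np.sum(np.array(edit_mask)) < 1:
    w, h = 64, 64
    mask = Image.new("L", (w, h), 255)           # every pixel marked editable
assert mask.getextrema() == (255, 255)
```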
@@ -358,43 +348,38 @@
                     err_msg = "Please provide the reference image or edited image."
                     return None, None, None, False, err_msg
             return edit_image, edit_mask, ref_image, is_suc, err_msg
+
         @spaces.GPU(duration=80)
         def run_chat(
-                prompt,
-                ref_image,
-                edit_image,
-                task_type,
-                edit_type,
-                cfg_scale,
-                step,
-                seed,
-                output_h,
-                output_w,
-                repainting_scale,
-                progress=gr.Progress(track_tqdm=True)
+                prompt,
+                ref_image,
+                edit_image,
+                edit_type,
+                cfg_scale,
+                step,
+                seed,
+                output_h,
+                output_w,
+                repainting_scale,
+                use_change,
+                keep_pixel,
+                keep_pixels_rate,
+                progress=gr.Progress(track_tqdm=True)
         ):
-            print(prompt)
-            model_path = self.task_model[task_type]["MODEL_PATH"]
             edit_info = self.edit_type_dict[edit_type]
-
-            if task_type in ["portrait", "subject"] and ref_image is None:
-                err_msg = "<mark>Please provide the reference image.</mark>"
-                return (gr.Image(), gr.Column(visible=True),
-                        gr.Image(),
-                        gr.Image(),
-                        gr.Text(value=err_msg))
-
             pre_edit_image, pre_edit_mask, pre_ref_image, is_suc, err_msg = preprocess_input(ref_image, edit_image)
+            icon = pre_edit_image or pre_ref_image
             if not is_suc:
                 err_msg = f"<mark>{err_msg}</mark>"
                 return (gr.Image(), gr.Column(visible=True),
+                        gr.Image(),
                         gr.Image(),
                         gr.Image(),
                         gr.Text(value=err_msg))
-            pre_edit_image = edit_preprocess(edit_info, we.device_id, pre_edit_image, pre_edit_mask)
+            pre_edit_image = edit_preprocess(edit_info.ANNOTATOR, we.device_id, pre_edit_image, pre_edit_mask)
             # edit_image["background"] = pre_edit_image
             st = time.time()
-            image, seed = self.pipe(
+            image, edit_image, change_image, mask, seed = self.pipe(
                 reference_image=pre_ref_image,
                 edit_image=pre_edit_image,
                 edit_mask=pre_edit_mask,
@@ -406,32 +391,43 @@
                 guide_scale=cfg_scale,
                 seed=seed,
                 repainting_scale=repainting_scale,
-                lora_path = model_path
+                use_change=use_change,
+                keep_pixels=keep_pixel,
+                keep_pixels_rate=keep_pixels_rate
             )
             et = time.time()
             msg = f"prompt: {prompt}; seed: {seed}; cost time: {et - st}s; repaiting scale: {repainting_scale}"
 
+            if icon is not None:
+                icon = resize_image(icon, 150)
+
             return (gr.Image(value=image), gr.Column(visible=True),
-                    gr.Image(value=pre_edit_image if pre_edit_image is not None else pre_ref_image),
+                    gr.Image(value=edit_image if edit_image is not None else edit_image),
+                    gr.Image(value=change_image),
                     gr.Image(value=pre_edit_mask if pre_edit_mask is not None else None),
-                    gr.Text(value=msg))
+                    gr.Text(value=msg),
+                    gr.Image(value=icon))
 
         chat_inputs = [
             self.reference_image,
             self.edit_image,
-            self.task_type,
             self.edit_type,
             self.cfg_scale,
             self.step,
             self.seed,
             self.output_height,
             self.output_width,
-            self.repainting_scale
+            self.repainting_scale,
+            self.use_change,
+            self.keep_pixel,
+            self.keep_pixels_rate
         ]
 
         chat_outputs = [
-            self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
-            self.edit_preprocess_mask_preview, self.generation_info_preview
+            self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
+            self.change_preprocess_preview,
+            self.edit_preprocess_mask_preview, self.generation_info_preview,
+            self.icon
         ]
 
         self.chat_btn.click(run_chat,
@@ -445,23 +441,26 @@
                             queue=True)
 
         @spaces.GPU(duration=80)
-        def run_example(task_type, edit_type, prompt, ref_image, edit_image, edit_mask,
-                        output_h, output_w, seed, progress=gr.Progress(track_tqdm=True)):
-            model_path = self.task_model[task_type]["MODEL_PATH"]
+        def run_example(edit_type, prompt, ref_image, edit_image, edit_mask,
+                        output_h, output_w, seed, use_change, keep_pixel,
+                        keep_pixels_rate,
+                        progress=gr.Progress(track_tqdm=True)):
 
             step = self.pipe.input.get("sample_steps", 20)
             cfg_scale = self.pipe.input.get("guide_scale", 20)
-
             edit_info = self.edit_type_dict[edit_type]
 
             edit_image = self.construct_edit_image(edit_image, edit_mask)
 
-            pre_edit_image, pre_edit_mask, pre_ref_image, is_suc, err_msg = preprocess_input(ref_image, edit_image)
-            pre_edit_image = edit_preprocess(edit_info, we.device_id, pre_edit_image, pre_edit_mask)
+            pre_edit_image, pre_edit_mask, pre_ref_image, _, _ = preprocess_input(ref_image, edit_image)
+
+            icon = pre_edit_image or pre_ref_image
+
+            pre_edit_image = edit_preprocess(edit_info.ANNOTATOR, we.device_id, pre_edit_image, pre_edit_mask)
             edit_info = edit_info or {}
             repainting_scale = edit_info.get("REPAINTING_SCALE", 1.0)
             st = time.time()
-            image, seed = self.pipe(
+            image, edit_image, change_image, mask, seed = self.pipe(
                 reference_image=pre_ref_image,
                 edit_image=pre_edit_image,
                 edit_mask=pre_edit_mask,
@@ -473,40 +472,52 @@
                 guide_scale=cfg_scale,
                 seed=seed,
                 repainting_scale=repainting_scale,
-                lora_path=model_path
+                use_change=use_change,
+                keep_pixels=keep_pixel,
+                keep_pixels_rate=keep_pixels_rate
             )
             et = time.time()
             msg = f"prompt: {prompt}; seed: {seed}; cost time: {et - st}s; repaiting scale: {repainting_scale}"
             if pre_edit_image is not None:
-                ret_image = Image.composite(Image.new("RGB", pre_edit_image.size, (0, 0, 0)), pre_edit_image, pre_edit_mask)
+                ret_image = Image.composite(Image.new("RGB", pre_edit_image.size, (0, 0, 0)), pre_edit_image,
+                                            pre_edit_mask)
             else:
                 ret_image = None
+
+            if icon is not None:
+                icon = resize_image(icon, 150)
             return (gr.Image(value=image), gr.Column(visible=True),
-                    gr.Image(value=pre_edit_image if pre_edit_image is not None else pre_ref_image),
+                    gr.Image(value=edit_image if edit_image is not None else edit_image),
+                    gr.Image(value=change_image),
                     gr.Image(value=pre_edit_mask if pre_edit_mask is not None else None),
                     gr.Text(value=msg),
-                    gr.update(value=ret_image))
+                    gr.update(value=ret_image),
+                    gr.Image(value=icon))
 
         with self.eg:
             self.example_edit_image = gr.Image(label='Edit Image',
-                                               type='pil',
-                                               image_mode='RGB',
-                                               visible=False)
+                                               type='pil',
+                                               image_mode='RGB',
+                                               visible=False)
             self.example_edit_mask = gr.Image(label='Edit Image Mask',
-                                              type='pil',
-                                              image_mode='L',
-                                              visible=False)
+                                              type='pil',
+                                              image_mode='L',
+                                              visible=False)
 
         self.examples = gr.Examples(
             fn=run_example,
             examples=self.all_examples,
             inputs=[
-                self.task_type, self.edit_type, self.text, self.reference_image, self.example_edit_image,
-                self.example_edit_mask, self.output_height, self.output_width, self.seed
+                self.edit_type, self.text, self.reference_image, self.example_edit_image,
+                self.example_edit_mask, self.output_height, self.output_width, self.seed,
+                self.use_change, self.keep_pixel, self.keep_pixels_rate
            ],
            outputs=[self.gallery_image, self.edit_preprocess_panel, self.edit_preprocess_preview,
-                     self.edit_preprocess_mask_preview, self.generation_info_preview, self.edit_image],
-            examples_per_page=6,
+                     self.change_preprocess_preview,
+                     self.edit_preprocess_mask_preview, self.generation_info_preview,
+                     self.edit_image,
+                     self.icon],
+            examples_per_page=15,
             cache_examples=False,
             run_on_click=True)
 
 
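Taken together, app.py now drives a single FFT pipeline whose call returns five values (previously `image, seed`). A hedged end-to-end sketch of invoking it outside Gradio; the `prompt`, `output_height`/`output_width`, and `sample_steps` keywords are assumptions carried over from context lines not shown in this diff, and the config path matches the `__init__` default above:

```python
from PIL import Image
from scepter.modules.utils.config import Config
from inference.registry import INFERENCES

pipe_cfg = Config(load=True, cfg_file="./config/ace_plus_fft.yaml")
pipe = INFERENCES.build(pipe_cfg)

edit_image = Image.new("RGB", (1024, 1024), (128, 128, 128))
edit_mask = Image.new("L", (1024, 1024), 255)    # repaint everything

image, out_edit_image, change_image, mask, seed = pipe(
    reference_image=None,
    edit_image=edit_image,
    edit_mask=edit_mask,
    prompt="a wooden table in warm light",       # assumed keyword
    output_height=1024,                          # assumed keyword
    output_width=1024,                           # assumed keyword
    sample_steps=20,                             # assumed keyword
    guide_scale=4.5,
    seed=-1,
    repainting_scale=1.0,
    use_change=True,
    keep_pixels=True,
    keep_pixels_rate=0.8,
)
image.save("result.png")
```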
inference/__init__.py CHANGED
@@ -1,2 +1,2 @@
-from .ace_plus_diffusers import ACEPlusDiffuserInference
+#from .ace_plus_diffusers import ACEPlusDiffuserInference
 from .ace_plus_inference import ACEInference