Commit a0efccd (parent: fb6e008): modify inference
Files changed:
- examples/examples.py            +154 -2
- infer_fft.py                    +178 -0
- infer_lora.py                   +228 -0
- inference/__init__.py           +2 -0
- inference/ace_plus_diffusers.py +7 -3
- inference/ace_plus_inference.py +83 -0
- inference/registry.py           +228 -0
- inference/utils.py              +38 -11
examples/examples.py  CHANGED

@@ -2,9 +2,9 @@ all_examples = [
     {
         "input_image": None,
         "input_mask": None,
-        "input_reference_image": "assets/samples/portrait/
+        "input_reference_image": "assets/samples/portrait/human_1.jpg",
         "save_path": "examples/outputs/portrait_human_1.jpg",
-        "instruction": "
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
         "output_h": 1024,
         "output_w": 1024,
         "seed": 4194866942,
@@ -78,4 +78,156 @@ all_examples = [
         "edit_type": "repainting"
     }
 
-]
+]
+
+fft_examples = [
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/portrait/human_1.jpg",
+        "save_path": "examples/outputs/portrait_human_1.jpg",
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/subject/subject_1.jpg",
+        "save_path": "examples/outputs/subject_subject_1.jpg",
+        "instruction": "Display the logo in a minimalist style printed in white on a matte black ceramic coffee mug, alongside a steaming cup of coffee on a cozy cafe table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_2_edit.jpg",
+        "input_mask": "./assets/samples/application/photo_editing/1_2_m.webp",
+        "input_reference_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "save_path": "examples/outputs/photo_editing_1.jpg",
+        "instruction": "The item is put on the table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 8006019,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/logo_paste/1_1_edit.png",
+        "input_mask": "./assets/samples/application/logo_paste/1_1_m.png",
+        "input_reference_image": "assets/samples/application/logo_paste/1_ref.png",
+        "save_path": "examples/outputs/logo_paste_1.jpg",
+        "instruction": "The logo is printed on the headphones.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/try_on/1_1_edit.png",
+        "input_mask": "./assets/samples/application/try_on/1_1_m.png",
+        "input_reference_image": "assets/samples/application/try_on/1_ref.png",
+        "save_path": "examples/outputs/try_on_1.jpg",
+        "instruction": "The woman dresses this skirt.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/portrait/human_1.jpg",
+        "input_mask": "assets/samples/application/movie_poster/1_2_m.webp",
+        "input_reference_image": "assets/samples/application/movie_poster/1_ref.png",
+        "save_path": "examples/outputs/movie_poster_1.jpg",
+        "instruction": "{image}, the man faces the camera.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 3999647,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/sr/sr_tiger.png",
+        "input_mask": "./assets/samples/application/sr/sr_tiger_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_recolorizing_1.jpg",
+        "instruction": "{image} features a close-up of a young, furry tiger cub on a rock. The tiger, which appears to be quite young, has distinctive orange, "
+                       "black, and white striped fur, typical of tigers. The cub's eyes have a bright and curious expression, and its ears are perked up, "
+                       "indicating alertness. The cub seems to be in the act of climbing or resting on the rock. The background is a blurred grassland with trees, "
+                       "but the focus is on the cub, which is vividly colored while the rest of the image is in grayscale, drawing attention to the tiger's details."
+                       " The photo captures a moment in the wild, depicting the charming and tenacious nature of this young tiger,"
+                       " as well as its typical interaction with the environment.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 199999,
+        "repainting_scale": 0.0,
+        "edit_type": "no_preprocess"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_orm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_1.jpg",
+        "instruction": "a blue hand",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 63401,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_rm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_2.jpg",
+        "instruction": "Mechanical hands like a robot",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 59107,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_recolorizing.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 9652101,
+        "repainting_scale": 0.0,
+        "edit_type": "recolorizing"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_depth.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 14979476,
+        "repainting_scale": 0.0,
+        "edit_type": "depth_repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_contour.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 4227292472,
+        "repainting_scale": 0.0,
+        "edit_type": "contour_repainting"
+    }
+]
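All of the entries above share one flat schema that the driver scripts unpack as keyword arguments; keys a driver does not consume (for example edit_type in the FFT driver) simply fall through **kwargs of run_one_case(). A minimal sketch of appending a custom case, where the save path and instruction are placeholders rather than assets shipped with this commit:

# Hypothetical extra entry; the keys mirror the dicts defined in examples/examples.py.
from examples.examples import fft_examples

custom_case = {
    "input_image": None,                  # nothing to edit: pure reference-guided generation
    "input_mask": None,
    "input_reference_image": "./assets/samples/portrait/human_1.jpg",
    "save_path": "examples/outputs/custom_case.jpg",      # placeholder output path
    "instruction": "Maintain the facial features, a girl reading in a library.",  # placeholder prompt
    "output_h": 1024,
    "output_w": 1024,
    "seed": -1,                           # a negative seed lets the pipeline draw a random one
    "repainting_scale": 1.0,
    "edit_type": "repainting"
}
fft_examples.append(custom_case)          # infer_fft.py will then run it along with the others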
infer_fft.py  ADDED (new file, 178 lines)

# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import glob
import importlib
import io
import os
import sys

from PIL import Image
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

if os.path.exists('__init__.py'):
    package_name = 'scepter_ext'
    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')
    package = importlib.util.module_from_spec(spec)
    sys.modules[package_name] = package
    spec.loader.exec_module(package)

from examples.examples import fft_examples as all_examples
from inference.registry import INFERENCES

fs_list = [
    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
]

for one_fs in fs_list:
    FS.init_fs_client(one_fs)


def run_one_case(pipe,
                 input_image=None,
                 input_mask=None,
                 input_reference_image=None,
                 save_path="examples/output/example.png",
                 instruction="",
                 output_h=1024,
                 output_w=1024,
                 seed=-1,
                 sample_steps=None,
                 guide_scale=None,
                 repainting_scale=None,
                 use_change=True,
                 keep_pixels=True,
                 keep_pixels_rate=0.8,
                 **kwargs):
    if input_image is not None:
        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
        input_image = pillow_convert(input_image, "RGB")
    if input_mask is not None:
        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
        input_mask = pillow_convert(input_mask, "L")
    if input_reference_image is not None:
        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
        input_reference_image = pillow_convert(input_reference_image, "RGB")
    print(repainting_scale)
    image, _, _, _, seed = pipe(
        reference_image=input_reference_image,
        edit_image=input_image,
        edit_mask=input_mask,
        prompt=instruction,
        output_height=output_h,
        output_width=output_w,
        sampler='flow_euler',
        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
        seed=seed,
        repainting_scale=repainting_scale,
        use_change=use_change,
        keep_pixels=keep_pixels,
        keep_pixels_rate=keep_pixels_rate
    )
    with FS.put_to(save_path) as local_path:
        image.save(local_path)
    return local_path, seed


def run():
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--instruction',
                        dest='instruction',
                        help='The instruction for editing or generating!',
                        default="")
    parser.add_argument('--output_h',
                        dest='output_h',
                        help='The height of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--output_w',
                        dest='output_w',
                        help='The width of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--input_reference_image',
                        dest='input_reference_image',
                        help='The input reference image!',
                        default=None)
    parser.add_argument('--input_image',
                        dest='input_image',
                        help='The input image!',
                        default=None)
    parser.add_argument('--input_mask',
                        dest='input_mask',
                        help='The input mask!',
                        default=None)
    parser.add_argument('--save_path',
                        dest='save_path',
                        help='The save path for output image!',
                        default='examples/output_images/output.png')
    parser.add_argument('--seed',
                        dest='seed',
                        help='The seed for generation!',
                        type=int,
                        default=-1)
    parser.add_argument('--step',
                        dest='step',
                        help='The sample step for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--guide_scale',
                        dest='guide_scale',
                        help='The guide scale for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--repainting_scale',
                        dest='repainting_scale',
                        help='The repainting scale for content filling generation!',
                        type=int,
                        default=None)

    cfg = Config(load=True, parser_ins=parser)
    model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
    pipe = INFERENCES.build(model_cfg)

    if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
        params = {
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale
        }
        # run examples
        for example in all_examples:
            example.update(params)
            local_path, seed = run_one_case(pipe, **example)
    else:
        params = {
            "input_image": cfg.args.input_image,
            "input_mask": cfg.args.input_mask,
            "input_reference_image": cfg.args.input_reference_image,
            "save_path": cfg.args.save_path,
            "instruction": cfg.args.instruction,
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale,
            "repainting_scale": cfg.args.repainting_scale,
        }
        local_path, seed = run_one_case(pipe, **params)
        print(local_path, seed)


if __name__ == '__main__':
    run()
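The script can also be driven programmatically instead of through argparse. A minimal sketch, assuming it is run from the repository root so that config/ace_plus_fft.yaml and the sample assets resolve; the smoke-test save path is a placeholder:

# Minimal programmatic use of the FFT inference path defined above.
from scepter.modules.utils.config import Config
from examples.examples import fft_examples
from inference.registry import INFERENCES
from infer_fft import run_one_case   # reuses the helper defined in this file

# Build whichever inference class the YAML names (expected to be ACEInference).
model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
pipe = INFERENCES.build(model_cfg)

# Run just the first bundled example instead of looping over the whole list.
case = dict(fft_examples[0])
case.update({"save_path": "examples/outputs/smoke_test.jpg",
             "sample_steps": 28,
             "guide_scale": 50})
local_path, seed = run_one_case(pipe, **case)
print(local_path, seed)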
infer_lora.py  ADDED (new file, 228 lines)

# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import glob
import io
import os

from PIL import Image
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

from examples.examples import all_examples
from inference.ace_plus_diffusers import ACEPlusDiffuserInference

inference_dict = {
    "ACE_DIFFUSER_PLUS": ACEPlusDiffuserInference
}

fs_list = [
    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
]

for one_fs in fs_list:
    FS.init_fs_client(one_fs)


def run_one_case(pipe,
                 input_image=None,
                 input_mask=None,
                 input_reference_image=None,
                 save_path="examples/output/example.png",
                 instruction="",
                 output_h=1024,
                 output_w=1024,
                 seed=-1,
                 sample_steps=None,
                 guide_scale=None,
                 repainting_scale=None,
                 model_path=None,
                 **kwargs):
    if input_image is not None:
        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
        input_image = pillow_convert(input_image, "RGB")
    if input_mask is not None:
        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
        input_mask = pillow_convert(input_mask, "L")
    if input_reference_image is not None:
        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
        input_reference_image = pillow_convert(input_reference_image, "RGB")

    image, seed = pipe(
        reference_image=input_reference_image,
        edit_image=input_image,
        edit_mask=input_mask,
        prompt=instruction,
        output_height=output_h,
        output_width=output_w,
        sampler='flow_euler',
        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
        seed=seed,
        repainting_scale=repainting_scale or pipe.input.get("repainting_scale", 1.0),
        lora_path=model_path
    )
    with FS.put_to(save_path) as local_path:
        image.save(local_path)
    return local_path, seed


def run():
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--instruction',
                        dest='instruction',
                        help='The instruction for editing or generating!',
                        default="")
    parser.add_argument('--output_h',
                        dest='output_h',
                        help='The height of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--output_w',
                        dest='output_w',
                        help='The width of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--input_reference_image',
                        dest='input_reference_image',
                        help='The input reference image!',
                        default=None)
    parser.add_argument('--input_image',
                        dest='input_image',
                        help='The input image!',
                        default=None)
    parser.add_argument('--input_mask',
                        dest='input_mask',
                        help='The input mask!',
                        default=None)
    parser.add_argument('--save_path',
                        dest='save_path',
                        help='The save path for output image!',
                        default='examples/output_images/output.png')
    parser.add_argument('--seed',
                        dest='seed',
                        help='The seed for generation!',
                        type=int,
                        default=-1)
    parser.add_argument('--step',
                        dest='step',
                        help='The sample step for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--guide_scale',
                        dest='guide_scale',
                        help='The guide scale for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--repainting_scale',
                        dest='repainting_scale',
                        help='The repainting scale for content filling generation!',
                        type=int,
                        default=None)
    parser.add_argument('--task_type',
                        dest='task_type',
                        choices=['portrait', 'subject', 'local_editing'],
                        help="Choose the task type.",
                        default='')
    parser.add_argument('--task_model',
                        dest='task_model',
                        help='The models list for different tasks!',
                        default="./models/model_zoo.yaml")
    parser.add_argument('--infer_type',
                        dest='infer_type',
                        choices=['diffusers'],
                        default='diffusers',
                        help="Choose the inference scripts. 'native' refers to using the official implementation of ace++, "
                             "while 'diffusers' refers to using the adaptation for diffusers")
    parser.add_argument('--cfg_folder',
                        dest='cfg_folder',
                        help='The inference config!',
                        default="./config")

    cfg = Config(load=True, parser_ins=parser)

    model_yamls = glob.glob(os.path.join(cfg.args.cfg_folder, '*.yaml'))
    model_choices = dict()
    for i in model_yamls:
        model_cfg = Config(load=True, cfg_file=i)
        model_name = model_cfg.NAME
        model_choices[model_name] = model_cfg

    if cfg.args.infer_type == "native":
        infer_name = "ace_plus_native_infer"
    elif cfg.args.infer_type == "diffusers":
        infer_name = "ace_plus_diffuser_infer"
    else:
        raise ValueError("infer_type should be native or diffusers")

    assert infer_name in model_choices

    # choose different model
    task_model_cfg = Config(load=True, cfg_file=cfg.args.task_model)

    task_model_dict = {}
    for task_name, task_model in task_model_cfg.MODEL.items():
        task_model_dict[task_name] = task_model

    # choose the inference scripts.
    pipe_cfg = model_choices[infer_name]
    infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE_PLUS")
    pipe = inference_dict[infer_name]()
    pipe.init_from_cfg(pipe_cfg)

    if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
        params = {
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale
        }
        # run examples
        for example in all_examples:
            example["model_path"] = FS.get_from(task_model_dict[example["task_type"].upper()]["MODEL_PATH"])
            example.update(params)
            if example["edit_type"] == "repainting":
                example["repainting_scale"] = 1.0
            else:
                example["repainting_scale"] = task_model_dict[example["task_type"].upper()].get("REPAINTING_SCALE", 1.0)
            print(example)
            local_path, seed = run_one_case(pipe, **example)
    else:
        assert cfg.args.task_type.upper() in task_model_cfg
        params = {
            "input_image": cfg.args.input_image,
            "input_mask": cfg.args.input_mask,
            "input_reference_image": cfg.args.input_reference_image,
            "save_path": cfg.args.save_path,
            "instruction": cfg.args.instruction,
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale,
            "repainting_scale": cfg.args.repainting_scale,
            "model_path": FS.get_from(task_model_dict[cfg.args.task_type.upper()]["MODEL_PATH"])
        }
        local_path, seed = run_one_case(pipe, **params)
        print(local_path, seed)


if __name__ == '__main__':
    run()
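The per-task LoRA selection above reduces to a small lookup against models/model_zoo.yaml. A sketch of that lookup in isolation; the "PORTRAIT" key and the exact YAML layout are inferred from the accesses in run() and are not shipped in this diff:

# Sketch of how infer_lora.py resolves a per-task LoRA and its repainting scale.
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

task_model_cfg = Config(load=True, cfg_file="./models/model_zoo.yaml")
task_model_dict = {name: model for name, model in task_model_cfg.MODEL.items()}

task_model = task_model_dict["PORTRAIT"]                  # "PORTRAIT" key is an assumption
lora_path = FS.get_from(task_model["MODEL_PATH"])         # fetch/cache the LoRA weights locally
repainting_scale = task_model.get("REPAINTING_SCALE", 1.0)
print(lora_path, repainting_scale)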
inference/__init__.py  CHANGED
@@ -0,0 +1,2 @@
+from .ace_plus_diffusers import ACEPlusDiffuserInference
+from .ace_plus_inference import ACEInference
inference/ace_plus_diffusers.py  CHANGED

@@ -12,7 +12,6 @@ from scepter.modules.utils.logger import get_logger
 from transformers import T5TokenizerFast
 from .utils import ACEPlusImageProcessor
 
-
 class ACEPlusDiffuserInference():
     def __init__(self, logger=None):
         if logger is None:
@@ -39,7 +38,6 @@ class ACEPlusDiffuserInference():
         self.pipe.tokenizer_2 = tokenizer_2
         self.load_default(cfg.DEFAULT_PARAS)
 
-
     def prepare_input(self,
                       image,
                       mask,
@@ -88,7 +86,11 @@ class ACEPlusDiffuserInference():
         if isinstance(prompt, str):
             prompt = [prompt]
         seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
-
+        # edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
+        image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
+                                                                                   width=output_width,
+                                                                                   height=output_height,
+                                                                                   repainting_scale=repainting_scale)
         h, w = image.shape[1:]
         generator = torch.Generator("cpu").manual_seed(seed)
         masked_image_latents = self.prepare_input(image, mask,
@@ -98,6 +100,8 @@ class ACEPlusDiffuserInference():
         with FS.get_from(lora_path) as local_path:
             self.pipe.load_lora_weights(local_path)
 
+
+
         image = self.pipe(
             prompt=prompt,
             masked_image_latents=masked_image_latents,
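End to end, this diffusers-backed class is used exactly as run_one_case() in infer_lora.py uses it. A condensed sketch; the config file name and the LoRA path below are placeholders, not values fixed by this commit:

# Condensed sketch of driving ACEPlusDiffuserInference directly.
from PIL import Image
from scepter.modules.utils.config import Config
from inference.ace_plus_diffusers import ACEPlusDiffuserInference

pipe = ACEPlusDiffuserInference()
pipe.init_from_cfg(Config(load=True, cfg_file="./config/ace_plus_diffuser_infer.yaml"))  # assumed file name

reference = Image.open("./assets/samples/portrait/human_1.jpg").convert("RGB")
image, seed = pipe(
    reference_image=reference,
    edit_image=None,
    edit_mask=None,
    prompt="Maintain the facial features, a girl wearing a police uniform.",
    output_height=1024,
    output_width=1024,
    sampler='flow_euler',
    sample_steps=28,
    guide_scale=50,
    seed=-1,
    repainting_scale=1.0,
    lora_path="path/to/ace_plus_portrait_lora.safetensors",  # placeholder; resolved from model_zoo.yaml in practice
)
image.save("examples/outputs/diffusers_demo.jpg")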
inference/ace_plus_inference.py  ADDED (new file, 83 lines)

# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict

import torch, numpy as np
from PIL import Image
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor


@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
    reuse the ldm code
    '''
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v
                      for k, v in cfg.SAMPLE_ARGS.items()}
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        # convert the input info to the input of ldm.
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
        image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)
        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                sampler='flow_euler',
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )
        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
                for x_i in out_image]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        edit_image = Image.fromarray((torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        change_image = Image.fromarray((torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
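Unlike the diffusers path, ACEInference returns five values (the final output image, the composed edit image, the change image, the mask, and the seed), which is why infer_fft.py unpacks image, _, _, _, seed. A small sketch of that contract, again assuming the repository root as the working directory and that config/ace_plus_fft.yaml names this class:

# Sketch: build ACEInference through the registry and unpack everything it returns.
from PIL import Image
from scepter.modules.utils.config import Config
from inference.registry import INFERENCES

pipe = INFERENCES.build(Config(load=True, cfg_file="config/ace_plus_fft.yaml"))

reference = Image.open("./assets/samples/portrait/human_1.jpg").convert("RGB")
out_image, edit_image, change_image, mask, seed = pipe(
    reference_image=reference,
    prompt="Maintain the facial features, a girl wearing a police uniform.",
    output_height=1024,
    output_width=1024,
    sample_steps=28,
    guide_scale=50,
    seed=-1,
    repainting_scale=1.0,
    use_change=True,
    keep_pixels=True,
    keep_pixels_rate=0.8,
)
out_image.save("examples/outputs/ace_inference_demo.jpg")
print("seed used:", seed)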
inference/registry.py  ADDED (new file, 228 lines)

# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from PIL.Image import Image
from collections import OrderedDict
from scepter.modules.utils.distribute import we
from scepter.modules.utils.config import Config
from scepter.modules.utils.logger import get_logger
from scepter.studio.utils.env import get_available_memory
from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
from scepter.modules.utils.registry import Registry, build_from_config


def get_model(model_tuple):
    assert 'model' in model_tuple
    return model_tuple['model']


class BaseInference():
    '''
    Supports loading components dynamically: a module is created and loaded
    the first time it is actually run.
    '''
    def __init__(self, cfg, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        for k, v in modules.items():
            self.__setattr__(k, v)

    def infer_model(self, cfg, module_paras=None):
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module

    def init_from_ckpt(self, path, model, ignore_keys=list()):
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)

        new_sd = OrderedDict()
        for k, v in sd.items():
            ignored = False
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    ignored = True
                    break
            if not ignored:
                new_sd[k] = v

        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

    def load(self, module):
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
                    module['cfg'].NAME in MODELS.class_map):
                model = MODELS.build(module['cfg'], logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
                    module['cfg'].NAME in BACKBONES.class_map):
                model = BACKBONES.build(module['cfg'],
                                        logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
                    module['cfg'].NAME in EMBEDDERS.class_map):
                model = EMBEDDERS.build(module['cfg'],
                                        logger=self.logger).eval()
            else:
                raise NotImplementedError
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                del module['model']
                module['model'] = None
                module['device'] = 'offline'
            print('delete module')
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if module['cfg'] != self.loaded_model[name]['cfg']:
                    self.unload(self.loaded_model[name])
                    module = self.load(module)
                    self.loaded_model[name] = module
                    return module
                elif module['device'] == 'cpu' or module['device'] == 'offline':
                    module = self.load(module)
                    return module
                else:
                    return module
            else:
                module = self.load(module)
                self.loaded_model[name] = module
                return module
        else:
            return self.load(module)

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            for name, module in self.loaded_model.items():
                module = self.unload(self.loaded_model[name])
                self.loaded_model[name] = module
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
            else:
                self.unload(module)
        else:
            self.unload(module)

    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass
        elif isinstance(image, Image):
            pass

    def get_function_info(self, module, function_name=None):
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            for k, v in all_function.items():
                return k, v['dtype']

    @torch.no_grad()
    def __call__(self,
                 input,
                 **kwargs):
        return


def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load a pretrained model if the key `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treat pretrain as the model path;
        dict: should contain the key `path`, plus other parameters taken by load_pretrained();
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'Config must be type dict, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model


# register cls for diffusion.
INFERENCES = Registry('INFERENCE', build_func=build_inference)
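The point of this new registry is that drivers like infer_fft.py never import a concrete class; they build whatever NAME the YAML config specifies. Registering another implementation therefore only takes a decorator. A sketch, where MyInference and its one-line config are hypothetical:

# Sketch: registering and building a custom inference wrapper via INFERENCES.
from scepter.modules.utils.config import Config
from inference.registry import BaseInference, INFERENCES


@INFERENCES.register_class()
class MyInference(BaseInference):
    """Trivial pass-through inference, just to exercise the registry."""

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)

    def __call__(self, input, **kwargs):
        return input


cfg = Config(cfg_dict={"NAME": "MyInference"}, load=False)
pipe = INFERENCES.build(cfg)      # dispatches through build_inference() above
print(pipe.name)                  # -> "MyInference"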
inference/utils.py  CHANGED

@@ -49,7 +49,10 @@ class ACEPlusImageProcessor():
                    edit_mask=None,
                    height=1024,
                    width=1024,
-                   repainting_scale = 1.0
+                   repainting_scale=1.0,
+                   keep_pixels=False,
+                   keep_pixels_rate=0.8,
+                   use_change=False):
         reference_image = self.image_check(reference_image)
         edit_image = self.image_check(edit_image)
         # for reference generation
@@ -57,8 +60,12 @@ class ACEPlusImageProcessor():
             edit_image = torch.zeros([3, height, width])
             edit_mask = torch.ones([1, height, width])
         else:
-            edit_mask
-
+            if edit_mask is None:
+                _, eH, eW = edit_image.shape
+                edit_mask = np.ones((eH, eW))
+            else:
+                edit_mask = np.asarray(edit_mask)
+                edit_mask = np.where(edit_mask > 128, 1, 0)
             edit_mask = edit_mask.astype(
                 np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
                     np.float32)
@@ -71,12 +78,27 @@ class ACEPlusImageProcessor():
 
         assert edit_mask is not None
         if reference_image is not None:
-            # align height with edit_image
             _, H, W = reference_image.shape
             _, eH, eW = edit_image.shape
-
-
-
+            if not keep_pixels:
+                # align height with edit_image
+                scale = eH / H
+                tH, tW = eH, int(W * scale)
+                reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                    reference_image)
+            else:
+                # padding
+                if H >= keep_pixels_rate * eH:
+                    tH = int(eH * keep_pixels_rate)
+                    scale = tH / H
+                    tW = int(W * scale)
+                    reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                        reference_image)
+                rH, rW = reference_image.shape[-2:]
+                delta_w = 0
+                delta_h = eH - rH
+                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+                reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
             edit_image = torch.cat([reference_image, edit_image], dim=-1)
             edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
             slice_w = reference_image.shape[-1]
@@ -89,16 +111,21 @@ class ACEPlusImageProcessor():
         rW = int(W * scale) // self.d * self.d
         slice_w = int(slice_w * scale) // self.d * self.d
 
-        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.
+        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
         edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
-
-
+        content_image = edit_image
+        if use_change:
+            change_image = edit_image * edit_mask
+            edit_image = edit_image * (1 - edit_mask)
+        else:
+            change_image = None
+        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
 
 
     def postprocess(self, image, slice_w, out_w, out_h):
         w, h = image.size
         if slice_w > 0:
-            output_image = image.crop((slice_w +
+            output_image = image.crop((slice_w + 30, 0, w, h))
             output_image = output_image.resize((out_w, out_h))
         else:
             output_image = image
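The reworked preprocess() can be exercised on its own with dummy images, which makes the new keep_pixels/use_change behaviour easy to inspect. A self-contained sketch; max_seq_len=4096 is an assumed value (ACEInference reads it from cfg.MAX_SEQ_LEN), and PIL inputs are assumed to be accepted because the driver scripts pass PIL images straight through:

# Sketch of the new preprocess() contract in inference/utils.py (no model weights needed).
from PIL import Image
from inference.utils import ACEPlusImageProcessor

processor = ACEPlusImageProcessor(max_seq_len=4096)    # max_seq_len value is an assumption

reference = Image.new("RGB", (512, 768), "white")      # reference is rescaled, or padded when keep_pixels=True
edit = Image.new("RGB", (1024, 1024), "gray")          # image being edited
mask = Image.new("L", (1024, 1024), 255)               # values > 128 become the repainting region

edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w = processor.preprocess(
    reference, edit, mask,
    height=1024, width=1024,
    repainting_scale=1.0,
    keep_pixels=True, keep_pixels_rate=0.8,
    use_change=True)

# The reference and the edit image are concatenated along width; slice_w records where
# the reference part ends so postprocess() can crop it back off (slice_w + 30, see above).
print(edit_image.shape, edit_mask.shape, out_h, out_w, slice_w)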