chaojiemao committed
Commit a0efccd
1 Parent(s): fb6e008

modify inference
examples/examples.py CHANGED
@@ -2,9 +2,9 @@ all_examples = [
     {
         "input_image": None,
         "input_mask": None,
-        "input_reference_image": "assets/samples/portrait/8f13fc996c99688f3af8e2300848a001.jpg",
+        "input_reference_image": "assets/samples/portrait/human_1.jpg",
         "save_path": "examples/outputs/portrait_human_1.jpg",
-        "instruction": "Dress the character in the image with elf ears and a wizard's robe, transforming them into a mage character from a fantasy world.",
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
         "output_h": 1024,
         "output_w": 1024,
         "seed": 4194866942,
@@ -78,4 +78,156 @@ all_examples = [
         "edit_type": "repainting"
     }
 
+]
+
+fft_examples = [
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/portrait/human_1.jpg",
+        "save_path": "examples/outputs/portrait_human_1.jpg",
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "./assets/samples/subject/subject_1.jpg",
+        "save_path": "examples/outputs/subject_subject_1.jpg",
+        "instruction": "Display the logo in a minimalist style printed in white on a matte black ceramic coffee mug, alongside a steaming cup of coffee on a cozy cafe table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 10000000,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_2_edit.jpg",
+        "input_mask": "./assets/samples/application/photo_editing/1_2_m.webp",
+        "input_reference_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "save_path": "examples/outputs/photo_editing_1.jpg",
+        "instruction": "The item is put on the table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 8006019,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/logo_paste/1_1_edit.png",
+        "input_mask": "./assets/samples/application/logo_paste/1_1_m.png",
+        "input_reference_image": "assets/samples/application/logo_paste/1_ref.png",
+        "save_path": "examples/outputs/logo_paste_1.jpg",
+        "instruction": "The logo is printed on the headphones.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/try_on/1_1_edit.png",
+        "input_mask": "./assets/samples/application/try_on/1_1_m.png",
+        "input_reference_image": "assets/samples/application/try_on/1_ref.png",
+        "save_path": "examples/outputs/try_on_1.jpg",
+        "instruction": "The woman dresses this skirt.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/portrait/human_1.jpg",
+        "input_mask": "assets/samples/application/movie_poster/1_2_m.webp",
+        "input_reference_image": "assets/samples/application/movie_poster/1_ref.png",
+        "save_path": "examples/outputs/movie_poster_1.jpg",
+        "instruction": "{image}, the man faces the camera.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 3999647,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/sr/sr_tiger.png",
+        "input_mask": "./assets/samples/application/sr/sr_tiger_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_recolorizing_1.jpg",
+        "instruction": "{image} features a close-up of a young, furry tiger cub on a rock. The tiger, which appears to be quite young, has distinctive orange, "
+                       "black, and white striped fur, typical of tigers. The cub's eyes have a bright and curious expression, and its ears are perked up, "
+                       "indicating alertness. The cub seems to be in the act of climbing or resting on the rock. The background is a blurred grassland with trees, "
+                       "but the focus is on the cub, which is vividly colored while the rest of the image is in grayscale, drawing attention to the tiger's details."
+                       " The photo captures a moment in the wild, depicting the charming and tenacious nature of this young tiger,"
+                       " as well as its typical interaction with the environment.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 199999,
+        "repainting_scale": 0.0,
+        "edit_type": "no_preprocess"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_orm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_1.jpg",
+        "instruction": "a blue hand",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 63401,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/application/photo_editing/1_ref.png",
+        "input_mask": "./assets/samples/application/photo_editing/1_1_rm.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/mario_repainting_2.jpg",
+        "instruction": "Mechanical hands like a robot",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 59107,
+        "repainting_scale": 1.0,
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_recolorizing.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 9652101,
+        "repainting_scale": 0.0,
+        "edit_type": "recolorizing"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_depth.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 14979476,
+        "repainting_scale": 0.0,
+        "edit_type": "depth_repainting"
+    },
+    {
+        "input_image": "./assets/samples/control/1_1.webp",
+        "input_mask": "./assets/samples/control/1_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/control_contour.jpg",
+        "instruction": "{image} Beautiful female portrait, Robot with smooth White transparent carbon shell, rococo detailing, Natural lighting, Highly detailed, Cinematic, 4K.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 4227292472,
+        "repainting_scale": 0.0,
+        "edit_type": "contour_repainting"
+    }
 ]
infer_fft.py ADDED
@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import glob
import importlib
import io
import os
import sys

from PIL import Image
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

if os.path.exists('__init__.py'):
    package_name = 'scepter_ext'
    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')
    package = importlib.util.module_from_spec(spec)
    sys.modules[package_name] = package
    spec.loader.exec_module(package)

from examples.examples import fft_examples as all_examples
from inference.registry import INFERENCES

fs_list = [
    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
]

for one_fs in fs_list:
    FS.init_fs_client(one_fs)


def run_one_case(pipe,
                 input_image=None,
                 input_mask=None,
                 input_reference_image=None,
                 save_path="examples/output/example.png",
                 instruction="",
                 output_h=1024,
                 output_w=1024,
                 seed=-1,
                 sample_steps=None,
                 guide_scale=None,
                 repainting_scale=None,
                 use_change=True,
                 keep_pixels=True,
                 keep_pixels_rate=0.8,
                 **kwargs):
    if input_image is not None:
        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
        input_image = pillow_convert(input_image, "RGB")
    if input_mask is not None:
        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
        input_mask = pillow_convert(input_mask, "L")
    if input_reference_image is not None:
        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
        input_reference_image = pillow_convert(input_reference_image, "RGB")
    print(repainting_scale)
    image, _, _, _, seed = pipe(
        reference_image=input_reference_image,
        edit_image=input_image,
        edit_mask=input_mask,
        prompt=instruction,
        output_height=output_h,
        output_width=output_w,
        sampler='flow_euler',
        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
        seed=seed,
        repainting_scale=repainting_scale,
        use_change=use_change,
        keep_pixels=keep_pixels,
        keep_pixels_rate=keep_pixels_rate
    )
    with FS.put_to(save_path) as local_path:
        image.save(local_path)
    return local_path, seed


def run():
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--instruction',
                        dest='instruction',
                        help='The instruction for editing or generating!',
                        default="")
    parser.add_argument('--output_h',
                        dest='output_h',
                        help='The height of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--output_w',
                        dest='output_w',
                        help='The width of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--input_reference_image',
                        dest='input_reference_image',
                        help='The input reference image!',
                        default=None)
    parser.add_argument('--input_image',
                        dest='input_image',
                        help='The input image!',
                        default=None)
    parser.add_argument('--input_mask',
                        dest='input_mask',
                        help='The input mask!',
                        default=None)
    parser.add_argument('--save_path',
                        dest='save_path',
                        help='The save path for output image!',
                        default='examples/output_images/output.png')
    parser.add_argument('--seed',
                        dest='seed',
                        help='The seed for generation!',
                        type=int,
                        default=-1)
    parser.add_argument('--step',
                        dest='step',
                        help='The sample step for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--guide_scale',
                        dest='guide_scale',
                        help='The guide scale for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--repainting_scale',
                        dest='repainting_scale',
                        help='The repainting scale for content filling generation!',
                        type=int,
                        default=None)

    cfg = Config(load=True, parser_ins=parser)
    model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
    pipe = INFERENCES.build(model_cfg)

    if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
        params = {
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale
        }
        # run examples
        for example in all_examples:
            example.update(params)
            local_path, seed = run_one_case(pipe, **example)
    else:
        params = {
            "input_image": cfg.args.input_image,
            "input_mask": cfg.args.input_mask,
            "input_reference_image": cfg.args.input_reference_image,
            "save_path": cfg.args.save_path,
            "instruction": cfg.args.instruction,
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale,
            "repainting_scale": cfg.args.repainting_scale,
        }
        local_path, seed = run_one_case(pipe, **params)
        print(local_path, seed)


if __name__ == '__main__':
    run()
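For orientation, here is a minimal programmatic sketch of a single infer_fft.py case; it reuses run_one_case and config/ace_plus_fft.yaml exactly as the script's run() does, while the input path, instruction, and output path are placeholder assumptions.

from scepter.modules.utils.config import Config
from inference.registry import INFERENCES
from infer_fft import run_one_case

# Build the inference pipeline from the same config the script loads.
model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
pipe = INFERENCES.build(model_cfg)

# Placeholder inputs; any field of an fft_examples entry can be passed the same way.
local_path, seed = run_one_case(
    pipe,
    input_reference_image="./assets/samples/portrait/human_1.jpg",
    instruction="Maintain the facial features, a girl wearing a neat police uniform.",
    save_path="examples/outputs/portrait_demo.jpg",
    seed=-1,                  # -1 lets the pipeline draw a random seed
    repainting_scale=1.0)
print(local_path, seed)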
infer_lora.py ADDED
@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import glob
import io
import os

from PIL import Image
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

from examples.examples import all_examples
from inference.ace_plus_diffusers import ACEPlusDiffuserInference

inference_dict = {
    "ACE_DIFFUSER_PLUS": ACEPlusDiffuserInference
}

fs_list = [
    Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./cache"}, load=False),
    Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./cache"}, load=False),
]

for one_fs in fs_list:
    FS.init_fs_client(one_fs)


def run_one_case(pipe,
                 input_image=None,
                 input_mask=None,
                 input_reference_image=None,
                 save_path="examples/output/example.png",
                 instruction="",
                 output_h=1024,
                 output_w=1024,
                 seed=-1,
                 sample_steps=None,
                 guide_scale=None,
                 repainting_scale=None,
                 model_path=None,
                 **kwargs):
    if input_image is not None:
        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
        input_image = pillow_convert(input_image, "RGB")
    if input_mask is not None:
        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
        input_mask = pillow_convert(input_mask, "L")
    if input_reference_image is not None:
        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
        input_reference_image = pillow_convert(input_reference_image, "RGB")

    image, seed = pipe(
        reference_image=input_reference_image,
        edit_image=input_image,
        edit_mask=input_mask,
        prompt=instruction,
        output_height=output_h,
        output_width=output_w,
        sampler='flow_euler',
        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
        seed=seed,
        repainting_scale=repainting_scale or pipe.input.get("repainting_scale", 1.0),
        lora_path=model_path
    )
    with FS.put_to(save_path) as local_path:
        image.save(local_path)
    return local_path, seed


def run():
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--instruction',
                        dest='instruction',
                        help='The instruction for editing or generating!',
                        default="")
    parser.add_argument('--output_h',
                        dest='output_h',
                        help='The height of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--output_w',
                        dest='output_w',
                        help='The width of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--input_reference_image',
                        dest='input_reference_image',
                        help='The input reference image!',
                        default=None)
    parser.add_argument('--input_image',
                        dest='input_image',
                        help='The input image!',
                        default=None)
    parser.add_argument('--input_mask',
                        dest='input_mask',
                        help='The input mask!',
                        default=None)
    parser.add_argument('--save_path',
                        dest='save_path',
                        help='The save path for output image!',
                        default='examples/output_images/output.png')
    parser.add_argument('--seed',
                        dest='seed',
                        help='The seed for generation!',
                        type=int,
                        default=-1)
    parser.add_argument('--step',
                        dest='step',
                        help='The sample step for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--guide_scale',
                        dest='guide_scale',
                        help='The guide scale for generation!',
                        type=int,
                        default=None)
    parser.add_argument('--repainting_scale',
                        dest='repainting_scale',
                        help='The repainting scale for content filling generation!',
                        type=int,
                        default=None)
    parser.add_argument('--task_type',
                        dest='task_type',
                        choices=['portrait', 'subject', 'local_editing'],
                        help="Choose the task type.",
                        default='')
    parser.add_argument('--task_model',
                        dest='task_model',
                        help='The models list for different tasks!',
                        default="./models/model_zoo.yaml")
    parser.add_argument('--infer_type',
                        dest='infer_type',
                        choices=['diffusers'],
                        default='diffusers',
                        help="Choose the inference scripts. 'native' refers to using the official implementation of ace++, "
                             "while 'diffusers' refers to using the adaptation for diffusers")
    parser.add_argument('--cfg_folder',
                        dest='cfg_folder',
                        help='The inference config!',
                        default="./config")

    cfg = Config(load=True, parser_ins=parser)

    model_yamls = glob.glob(os.path.join(cfg.args.cfg_folder, '*.yaml'))
    model_choices = dict()
    for i in model_yamls:
        model_cfg = Config(load=True, cfg_file=i)
        model_name = model_cfg.NAME
        model_choices[model_name] = model_cfg

    if cfg.args.infer_type == "native":
        infer_name = "ace_plus_native_infer"
    elif cfg.args.infer_type == "diffusers":
        infer_name = "ace_plus_diffuser_infer"
    else:
        raise ValueError("infer_type should be native or diffusers")

    assert infer_name in model_choices

    # choose different model
    task_model_cfg = Config(load=True, cfg_file=cfg.args.task_model)

    task_model_dict = {}
    for task_name, task_model in task_model_cfg.MODEL.items():
        task_model_dict[task_name] = task_model

    # choose the inference scripts.
    pipe_cfg = model_choices[infer_name]
    infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE_PLUS")
    pipe = inference_dict[infer_name]()
    pipe.init_from_cfg(pipe_cfg)

    if cfg.args.instruction == "" and cfg.args.input_image is None and cfg.args.input_reference_image is None:
        params = {
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale
        }
        # run examples
        for example in all_examples:
            example["model_path"] = FS.get_from(task_model_dict[example["task_type"].upper()]["MODEL_PATH"])
            example.update(params)
            if example["edit_type"] == "repainting":
                example["repainting_scale"] = 1.0
            else:
                example["repainting_scale"] = task_model_dict[example["task_type"].upper()].get("REPAINTING_SCALE", 1.0)
            print(example)
            local_path, seed = run_one_case(pipe, **example)
    else:
        assert cfg.args.task_type.upper() in task_model_cfg
        params = {
            "input_image": cfg.args.input_image,
            "input_mask": cfg.args.input_mask,
            "input_reference_image": cfg.args.input_reference_image,
            "save_path": cfg.args.save_path,
            "instruction": cfg.args.instruction,
            "output_h": cfg.args.output_h,
            "output_w": cfg.args.output_w,
            "sample_steps": cfg.args.step,
            "guide_scale": cfg.args.guide_scale,
            "repainting_scale": cfg.args.repainting_scale,
            "model_path": FS.get_from(task_model_dict[cfg.args.task_type.upper()]["MODEL_PATH"])
        }
        local_path, seed = run_one_case(pipe, **params)
        print(local_path, seed)


if __name__ == '__main__':
    run()
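A minimal programmatic counterpart to infer_lora.py, for orientation; it mirrors the script's pipe construction and run_one_case call. The config file name ace_plus_diffuser_infer.yaml and the PORTRAIT key in model_zoo.yaml are assumptions, the rest follows the code above.

from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS
from inference.ace_plus_diffusers import ACEPlusDiffuserInference
from infer_lora import run_one_case

# Assumed config file whose NAME is 'ace_plus_diffuser_infer' (see the ./config lookup above).
pipe_cfg = Config(load=True, cfg_file="./config/ace_plus_diffuser_infer.yaml")
pipe = ACEPlusDiffuserInference()
pipe.init_from_cfg(pipe_cfg)

# Resolve the task-specific LoRA the same way run() does; 'PORTRAIT' is an assumed key.
task_model_cfg = Config(load=True, cfg_file="./models/model_zoo.yaml")
task_model_dict = {k: v for k, v in task_model_cfg.MODEL.items()}
lora_path = FS.get_from(task_model_dict["PORTRAIT"]["MODEL_PATH"])

local_path, seed = run_one_case(
    pipe,
    input_reference_image="./assets/samples/portrait/human_1.jpg",   # placeholder input
    instruction="Maintain the facial features, a girl wearing a neat police uniform.",
    save_path="examples/outputs/portrait_lora_demo.jpg",
    repainting_scale=1.0,
    model_path=lora_path)
print(local_path, seed)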
inference/__init__.py CHANGED
@@ -0,0 +1,2 @@
+from .ace_plus_diffusers import ACEPlusDiffuserInference
+from .ace_plus_inference import ACEInference
inference/ace_plus_diffusers.py CHANGED
@@ -12,7 +12,6 @@ from scepter.modules.utils.logger import get_logger
 from transformers import T5TokenizerFast
 from .utils import ACEPlusImageProcessor
 
-
 class ACEPlusDiffuserInference():
     def __init__(self, logger=None):
         if logger is None:
@@ -39,7 +38,6 @@ class ACEPlusDiffuserInference():
         self.pipe.tokenizer_2 = tokenizer_2
         self.load_default(cfg.DEFAULT_PARAS)
 
-
     def prepare_input(self,
                       image,
                       mask,
@@ -88,7 +86,11 @@ class ACEPlusDiffuserInference():
         if isinstance(prompt, str):
             prompt = [prompt]
         seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
-        image, mask, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask, repainting_scale=repainting_scale)
+        # edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
+        image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(reference_image, edit_image, edit_mask,
+                                                                                   width=output_width,
+                                                                                   height=output_height,
+                                                                                   repainting_scale=repainting_scale)
         h, w = image.shape[1:]
         generator = torch.Generator("cpu").manual_seed(seed)
         masked_image_latents = self.prepare_input(image, mask,
@@ -98,6 +100,8 @@ class ACEPlusDiffuserInference():
         with FS.get_from(lora_path) as local_path:
             self.pipe.load_lora_weights(local_path)
 
+
+
         image = self.pipe(
             prompt=prompt,
             masked_image_latents=masked_image_latents,
inference/ace_plus_inference.py ADDED
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict

import torch, numpy as np
from PIL import Image
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor


@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
    reuse the ldm code
    '''
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for
                      k, v in cfg.SAMPLE_ARGS.items()}
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        # convert the input info to the input of ldm.
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
        image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)
        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                sampler='flow_euler',
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )
        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
                for x_i in out_image]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        edit_image = Image.fromarray((torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        change_image = Image.fromarray((torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
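A short usage sketch of the new ACEInference class, assuming config/ace_plus_fft.yaml is its config (as loaded in infer_fft.py); the input image path and output path are placeholders, and note that __call__ returns a five-tuple.

from PIL import Image
from scepter.modules.utils.config import Config
from inference.ace_plus_inference import ACEInference

cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")   # assumed ACEInference config
pipe = ACEInference(cfg)                                       # loads the model onto the current device

out_image, edit_image, change_image, mask, seed = pipe(
    edit_image=Image.open("./assets/samples/application/photo_editing/1_ref.png").convert("RGB"),
    edit_mask=None,          # preprocess() now builds an all-ones mask when the mask is omitted
    prompt="a blue hand",
    sample_steps=28,
    guide_scale=50,
    seed=-1,
    repainting_scale=1.0,
    use_change=True)         # use_change=True so change_image is a tensor rather than None
out_image.save("examples/outputs/ace_inference_demo.jpg")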
inference/registry.py ADDED
@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from PIL.Image import Image
from collections import OrderedDict
from scepter.modules.utils.distribute import we
from scepter.modules.utils.config import Config
from scepter.modules.utils.logger import get_logger
from scepter.studio.utils.env import get_available_memory
from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
from scepter.modules.utils.registry import Registry, build_from_config


def get_model(model_tuple):
    assert 'model' in model_tuple
    return model_tuple['model']


class BaseInference():
    '''
    Supports loading the components dynamically.
    A model is created and loaded the first time it is run.
    '''
    def __init__(self, cfg, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        for k, v in modules.items():
            self.__setattr__(k, v)

    def infer_model(self, cfg, module_paras=None):
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module

    def init_from_ckpt(self, path, model, ignore_keys=list()):
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)

        new_sd = OrderedDict()
        for k, v in sd.items():
            ignored = False
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    ignored = True
                    break
            if not ignored:
                new_sd[k] = v

        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

    def load(self, module):
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
                    module['cfg'].NAME in MODELS.class_map):
                model = MODELS.build(module['cfg'], logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
                    module['cfg'].NAME in BACKBONES.class_map):
                model = BACKBONES.build(module['cfg'],
                                        logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
                    module['cfg'].NAME in EMBEDDERS.class_map):
                model = EMBEDDERS.build(module['cfg'],
                                        logger=self.logger).eval()
            else:
                raise NotImplementedError
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                del module['model']
            module['model'] = None
            module['device'] = 'offline'
            print('delete module')
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if module['cfg'] != self.loaded_model[name]['cfg']:
                    self.unload(self.loaded_model[name])
                    module = self.load(module)
                    self.loaded_model[name] = module
                    return module
                elif module['device'] == 'cpu' or module['device'] == 'offline':
                    module = self.load(module)
                    return module
                else:
                    return module
            else:
                module = self.load(module)
                self.loaded_model[name] = module
                return module
        else:
            return self.load(module)

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            for name, module in self.loaded_model.items():
                module = self.unload(self.loaded_model[name])
                self.loaded_model[name] = module
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
            else:
                self.unload(module)
        else:
            self.unload(module)

    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v for k, v in cfg.INPUT.items()}
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass
        elif isinstance(image, Image):
            pass

    def get_function_info(self, module, function_name=None):
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            for k, v in all_function.items():
                return k, v['dtype']

    @torch.no_grad()
    def __call__(self,
                 input,
                 **kwargs):
        return


def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load a pretrained model if the key `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treat pretrain as the model path;
        dict: should contain the key `path`, plus other parameters taken by load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'Config must be type dict, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model


# register cls for diffusion.
INFERENCES = Registry('INFERENCE', build_func=build_inference)
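To illustrate the registry flow that infer_fft.py relies on, a minimal sketch follows; MyInference and its inline config are hypothetical, while the register_class/build calls mirror how ACEInference is wired up.

from scepter.modules.utils.config import Config
from inference.registry import BaseInference, INFERENCES

@INFERENCES.register_class()
class MyInference(BaseInference):
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)

    def __call__(self, prompt='', **kwargs):
        # Stand-in for real sampling logic.
        return prompt

# INFERENCES.build() dispatches on cfg.NAME through build_inference/build_from_config.
cfg = Config(cfg_dict={"NAME": "MyInference"}, load=False)
pipe = INFERENCES.build(cfg)
print(pipe("hello"))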
inference/utils.py CHANGED
@@ -49,7 +49,10 @@ class ACEPlusImageProcessor():
                    edit_mask=None,
                    height=1024,
                    width=1024,
-                   repainting_scale=1.0):
+                   repainting_scale=1.0,
+                   keep_pixels=False,
+                   keep_pixels_rate=0.8,
+                   use_change=False):
         reference_image = self.image_check(reference_image)
         edit_image = self.image_check(edit_image)
         # for reference generation
@@ -57,8 +60,12 @@
             edit_image = torch.zeros([3, height, width])
             edit_mask = torch.ones([1, height, width])
         else:
-            edit_mask = np.asarray(edit_mask)
-            edit_mask = np.where(edit_mask > 128, 1, 0)
+            if edit_mask is None:
+                _, eH, eW = edit_image.shape
+                edit_mask = np.ones((eH, eW))
+            else:
+                edit_mask = np.asarray(edit_mask)
+                edit_mask = np.where(edit_mask > 128, 1, 0)
             edit_mask = edit_mask.astype(
                 np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
                 np.float32)
@@ -71,12 +78,27 @@
 
         assert edit_mask is not None
         if reference_image is not None:
-            # align height with edit_image
             _, H, W = reference_image.shape
             _, eH, eW = edit_image.shape
-            scale = eH / H
-            tH, tW = eH, int(W * scale)
-            reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(reference_image)
+            if not keep_pixels:
+                # align height with edit_image
+                scale = eH / H
+                tH, tW = eH, int(W * scale)
+                reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                    reference_image)
+            else:
+                # padding
+                if H >= keep_pixels_rate * eH:
+                    tH = int(eH * keep_pixels_rate)
+                    scale = tH / H
+                    tW = int(W * scale)
+                    reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
+                        reference_image)
+                rH, rW = reference_image.shape[-2:]
+                delta_w = 0
+                delta_h = eH - rH
+                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+                reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
             edit_image = torch.cat([reference_image, edit_image], dim=-1)
             edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
             slice_w = reference_image.shape[-1]
@@ -89,16 +111,21 @@
         rW = int(W * scale) // self.d * self.d
         slice_w = int(slice_w * scale) // self.d * self.d
 
-        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(edit_image)
+        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
         edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
-
-        return edit_image, edit_mask, out_h, out_w, slice_w
+        content_image = edit_image
+        if use_change:
+            change_image = edit_image * edit_mask
+            edit_image = edit_image * (1 - edit_mask)
+        else:
+            change_image = None
+        return edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w
 
 
     def postprocess(self, image, slice_w, out_w, out_h):
         w, h = image.size
         if slice_w > 0:
-            output_image = image.crop((slice_w + 20, 0, w, h))
+            output_image = image.crop((slice_w + 30, 0, w, h))
             output_image = output_image.resize((out_w, out_h))
         else:
             output_image = image
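The preprocess() contract changes with this commit: it now accepts keep_pixels, keep_pixels_rate, and use_change, and returns seven values instead of five. A sketch of the new call follows, with a placeholder reference image and an assumed max_seq_len value.

from PIL import Image
from inference.utils import ACEPlusImageProcessor

processor = ACEPlusImageProcessor(max_seq_len=4096)   # max_seq_len value is an assumption
reference = Image.open("./assets/samples/portrait/human_1.jpg").convert("RGB")

edit_image, edit_mask, change_image, content_image, out_h, out_w, slice_w = processor.preprocess(
    reference,                # reference_image
    None,                     # edit_image: None triggers the reference-generation branch
    None,                     # edit_mask: filled with ones internally
    height=1024,
    width=1024,
    repainting_scale=1.0,
    keep_pixels=True,         # pad the reference toward the edit height instead of rescaling it fully
    keep_pixels_rate=0.8,
    use_change=True)          # change_image = edit_image * edit_mask; with use_change=False it is None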