Stephen committed on
Commit 4bb68d5 · 1 Parent(s): 6d32b43
Files changed (6)
  1. .gitignore +162 -0
  2. app.py +465 -0
  3. briarmbg.py +462 -0
  4. db_examples.py +3 -0
  5. models/model_download_here +0 -0
  6. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ *.safetensors
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
app.py ADDED
@@ -0,0 +1,465 @@
1
+ import os
2
+ import math
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import safetensors.torch as sf
7
+ import db_examples
8
+
9
+ from PIL import Image
10
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
11
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
12
+ from diffusers.models.attention_processor import AttnProcessor2_0
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from briarmbg import BriaRMBG
15
+ from enum import Enum
16
+ from torch.hub import download_url_to_file
17
+
18
+
19
+ # 'stablediffusionapi/realistic-vision-v51'
20
+ # 'runwayml/stable-diffusion-v1-5'
21
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
22
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
23
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
24
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
25
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
26
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
27
+
28
+ # Change UNet
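+ # IC-Light FBC feeds 12 channels into the UNet: the 4-channel noisy latent plus 4-channel foreground and background condition latents.
+ # The widened conv_in is zero-initialized in the new channels, so the base SD1.5 behaviour is preserved until the offsets are merged.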
29
+
30
+ with torch.no_grad():
31
+ new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
32
+ new_conv_in.weight.zero_()
33
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
34
+ new_conv_in.bias = unet.conv_in.bias
35
+ unet.conv_in = new_conv_in
36
+
37
+ unet_original_forward = unet.forward
38
+
39
+
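+ # The hook pulls the fg/bg condition latents out of cross_attention_kwargs, repeats them to match the sample batch size, and concatenates them along the channel axis before calling the original UNet forward.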
40
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
41
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
42
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
43
+ new_sample = torch.cat([sample, c_concat], dim=1)
44
+ kwargs['cross_attention_kwargs'] = {}
45
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
46
+
47
+
48
+ unet.forward = hooked_unet_forward
49
+
50
+ # Load
51
+
52
+ model_path = './models/iclight_sd15_fbc.safetensors'
53
+
54
+ if not os.path.exists(model_path):
55
+ download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors', dst=model_path)
56
+
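+ # The downloaded .safetensors stores weight offsets rather than full weights; they are merged by element-wise addition onto the base UNet state dict.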
57
+ sd_offset = sf.load_file(model_path)
58
+ sd_origin = unet.state_dict()
59
+ keys = sd_origin.keys()
60
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
61
+ unet.load_state_dict(sd_merged, strict=True)
62
+ del sd_offset, sd_origin, sd_merged, keys
63
+
64
+ # Device
65
+
66
+ device = torch.device('cuda')
67
+ text_encoder = text_encoder.to(device=device, dtype=torch.float16)
68
+ vae = vae.to(device=device, dtype=torch.bfloat16)
69
+ unet = unet.to(device=device, dtype=torch.float16)
70
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
71
+
72
+ # SDP
73
+
74
+ unet.set_attn_processor(AttnProcessor2_0())
75
+ vae.set_attn_processor(AttnProcessor2_0())
76
+
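+ # Three schedulers are built with matching SD1.5 betas; only the DPM++ 2M SDE Karras scheduler is wired into the pipelines below.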
77
+ # Samplers
78
+
79
+ ddim_scheduler = DDIMScheduler(
80
+ num_train_timesteps=1000,
81
+ beta_start=0.00085,
82
+ beta_end=0.012,
83
+ beta_schedule="scaled_linear",
84
+ clip_sample=False,
85
+ set_alpha_to_one=False,
86
+ steps_offset=1,
87
+ )
88
+
89
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
90
+ num_train_timesteps=1000,
91
+ beta_start=0.00085,
92
+ beta_end=0.012,
93
+ steps_offset=1
94
+ )
95
+
96
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
97
+ num_train_timesteps=1000,
98
+ beta_start=0.00085,
99
+ beta_end=0.012,
100
+ algorithm_type="sde-dpmsolver++",
101
+ use_karras_sigmas=True,
102
+ steps_offset=1
103
+ )
104
+
105
+ # Pipelines
106
+
107
+ t2i_pipe = StableDiffusionPipeline(
108
+ vae=vae,
109
+ text_encoder=text_encoder,
110
+ tokenizer=tokenizer,
111
+ unet=unet,
112
+ scheduler=dpmpp_2m_sde_karras_scheduler,
113
+ safety_checker=None,
114
+ requires_safety_checker=False,
115
+ feature_extractor=None,
116
+ image_encoder=None
117
+ )
118
+
119
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
120
+ vae=vae,
121
+ text_encoder=text_encoder,
122
+ tokenizer=tokenizer,
123
+ unet=unet,
124
+ scheduler=dpmpp_2m_sde_karras_scheduler,
125
+ safety_checker=None,
126
+ requires_safety_checker=False,
127
+ feature_extractor=None,
128
+ image_encoder=None
129
+ )
130
+
131
+
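+ # Long prompts are split into 75-token chunks, each wrapped with BOS/EOS and padded to 77 tokens, then encoded with the CLIP text encoder; encode_prompt_pair later repeats and concatenates the chunk embeddings so positive and negative prompts end up the same length.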
132
+ @torch.inference_mode()
133
+ def encode_prompt_inner(txt: str):
134
+ max_length = tokenizer.model_max_length
135
+ chunk_length = tokenizer.model_max_length - 2
136
+ id_start = tokenizer.bos_token_id
137
+ id_end = tokenizer.eos_token_id
138
+ id_pad = id_end
139
+
140
+ def pad(x, p, i):
141
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
142
+
143
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
144
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
145
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
146
+
147
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
148
+ conds = text_encoder(token_ids).last_hidden_state
149
+
150
+ return conds
151
+
152
+
153
+ @torch.inference_mode()
154
+ def encode_prompt_pair(positive_prompt, negative_prompt):
155
+ c = encode_prompt_inner(positive_prompt)
156
+ uc = encode_prompt_inner(negative_prompt)
157
+
158
+ c_len = float(len(c))
159
+ uc_len = float(len(uc))
160
+ max_count = max(c_len, uc_len)
161
+ c_repeat = int(math.ceil(max_count / c_len))
162
+ uc_repeat = int(math.ceil(max_count / uc_len))
163
+ max_chunk = max(len(c), len(uc))
164
+
165
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
166
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
167
+
168
+ c = torch.cat([p[None, ...] for p in c], dim=1)
169
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
170
+
171
+ return c, uc
172
+
173
+
174
+ @torch.inference_mode()
175
+ def pytorch2numpy(imgs, quant=True):
176
+ results = []
177
+ for x in imgs:
178
+ y = x.movedim(0, -1)
179
+
180
+ if quant:
181
+ y = y * 127.5 + 127.5
182
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
183
+ else:
184
+ y = y * 0.5 + 0.5
185
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
186
+
187
+ results.append(y)
188
+ return results
189
+
190
+
191
+ @torch.inference_mode()
192
+ def numpy2pytorch(imgs):
193
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
194
+ h = h.movedim(-1, 1)
195
+ return h
196
+
197
+
198
+ def resize_and_center_crop(image, target_width, target_height):
199
+ pil_image = Image.fromarray(image)
200
+ original_width, original_height = pil_image.size
201
+ scale_factor = max(target_width / original_width, target_height / original_height)
202
+ resized_width = int(round(original_width * scale_factor))
203
+ resized_height = int(round(original_height * scale_factor))
204
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
205
+ left = (resized_width - target_width) / 2
206
+ top = (resized_height - target_height) / 2
207
+ right = (resized_width + target_width) / 2
208
+ bottom = (resized_height + target_height) / 2
209
+ cropped_image = resized_image.crop((left, top, right, bottom))
210
+ return np.array(cropped_image)
211
+
212
+
213
+ def resize_without_crop(image, target_width, target_height):
214
+ pil_image = Image.fromarray(image)
215
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
216
+ return np.array(resized_image)
217
+
218
+
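+ # Matting: resize the input to roughly a 1024x1024-pixel area (sides rounded to multiples of 64), predict the alpha with RMBG, upsample it back to the original size, and composite the foreground over neutral grey (127).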
219
+ @torch.inference_mode()
220
+ def run_rmbg(img, sigma=0.0):
221
+ H, W, C = img.shape
222
+ assert C == 3
223
+ k = (256.0 / float(H * W)) ** 0.5
224
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
225
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
226
+ alpha = rmbg(feed)[0][0]
227
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
228
+ alpha = alpha.movedim(1, -1)[0]
229
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
230
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
231
+ return result.clip(0, 255).astype(np.uint8), alpha
232
+
233
+
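+ # Two-pass generation: a text-to-image pass conditioned on the VAE-encoded fg/bg latents, then an img2img refinement pass at highres_scale resolution with highres_denoise as the strength.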
234
+ @torch.inference_mode()
235
+ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
236
+ bg_source = BGSource(bg_source)
237
+
238
+ if bg_source == BGSource.UPLOAD:
239
+ pass
240
+ elif bg_source == BGSource.UPLOAD_FLIP:
241
+ input_bg = np.fliplr(input_bg)
242
+ elif bg_source == BGSource.GREY:
243
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
244
+ elif bg_source == BGSource.LEFT:
245
+ gradient = np.linspace(224, 32, image_width)
246
+ image = np.tile(gradient, (image_height, 1))
247
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
248
+ elif bg_source == BGSource.RIGHT:
249
+ gradient = np.linspace(32, 224, image_width)
250
+ image = np.tile(gradient, (image_height, 1))
251
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
252
+ elif bg_source == BGSource.TOP:
253
+ gradient = np.linspace(224, 32, image_height)[:, None]
254
+ image = np.tile(gradient, (1, image_width))
255
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
256
+ elif bg_source == BGSource.BOTTOM:
257
+ gradient = np.linspace(32, 224, image_height)[:, None]
258
+ image = np.tile(gradient, (1, image_width))
259
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
260
+ else:
261
+ raise ValueError('Wrong background source!')
262
+
263
+ rng = torch.Generator(device=device).manual_seed(seed)
264
+
265
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
266
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
267
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
268
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
269
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
270
+
271
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
272
+
273
+ latents = t2i_pipe(
274
+ prompt_embeds=conds,
275
+ negative_prompt_embeds=unconds,
276
+ width=image_width,
277
+ height=image_height,
278
+ num_inference_steps=steps,
279
+ num_images_per_prompt=num_samples,
280
+ generator=rng,
281
+ output_type='latent',
282
+ guidance_scale=cfg,
283
+ cross_attention_kwargs={'concat_conds': concat_conds},
284
+ ).images.to(vae.dtype) / vae.config.scaling_factor
285
+
286
+ pixels = vae.decode(latents).sample
287
+ pixels = pytorch2numpy(pixels)
288
+ pixels = [resize_without_crop(
289
+ image=p,
290
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
291
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
292
+ for p in pixels]
293
+
294
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
295
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
296
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
297
+
298
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
299
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
300
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
301
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
302
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
303
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
304
+
305
+ latents = i2i_pipe(
306
+ image=latents,
307
+ strength=highres_denoise,
308
+ prompt_embeds=conds,
309
+ negative_prompt_embeds=unconds,
310
+ width=image_width,
311
+ height=image_height,
312
+ num_inference_steps=int(round(steps / highres_denoise)),
313
+ num_images_per_prompt=num_samples,
314
+ generator=rng,
315
+ output_type='latent',
316
+ guidance_scale=cfg,
317
+ cross_attention_kwargs={'concat_conds': concat_conds},
318
+ ).images.to(vae.dtype) / vae.config.scaling_factor
319
+
320
+ pixels = vae.decode(latents).sample
321
+ pixels = pytorch2numpy(pixels, quant=False)
322
+
323
+ return pixels, [fg, bg]
324
+
325
+
326
+ @torch.inference_mode()
327
+ def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
328
+ input_fg, matting = run_rmbg(input_fg)
329
+ results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
330
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
331
+ return results + extra_images
332
+
333
+
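+ # Normal estimation: relight with left/right/bottom/top light, divide each result by their average, and derive (u, v, h) components in the spirit of photometric stereo; the matte blends the estimate with a flat (0, 0, 1) background normal.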
334
+ @torch.inference_mode()
335
+ def process_normal(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
336
+ input_fg, matting = run_rmbg(input_fg, sigma=16)
337
+
338
+ print('left ...')
339
+ left = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.LEFT.value)[0][0]
340
+
341
+ print('right ...')
342
+ right = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.RIGHT.value)[0][0]
343
+
344
+ print('bottom ...')
345
+ bottom = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.BOTTOM.value)[0][0]
346
+
347
+ print('top ...')
348
+ top = process(input_fg, input_bg, prompt, image_width, image_height, 1, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, BGSource.TOP.value)[0][0]
349
+
350
+ inner_results = [left * 2.0 - 1.0, right * 2.0 - 1.0, bottom * 2.0 - 1.0, top * 2.0 - 1.0]
351
+
352
+ ambient = (left + right + bottom + top) / 4.0
353
+ h, w, _ = ambient.shape
354
+ matting = resize_and_center_crop((matting[..., 0] * 255.0).clip(0, 255).astype(np.uint8), w, h).astype(np.float32)[..., None] / 255.0
355
+
356
+ def safa_divide(a, b):
357
+ e = 1e-5
358
+ return ((a + e) / (b + e)) - 1.0
359
+
360
+ left = safa_divide(left, ambient)
361
+ right = safa_divide(right, ambient)
362
+ bottom = safa_divide(bottom, ambient)
363
+ top = safa_divide(top, ambient)
364
+
365
+ u = (right - left) * 0.5
366
+ v = (top - bottom) * 0.5
367
+
368
+ sigma = 10.0
369
+ u = np.mean(u, axis=2)
370
+ v = np.mean(v, axis=2)
371
+ h = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma)
372
+ z = np.zeros_like(h)
373
+
374
+ normal = np.stack([u, v, h], axis=2)
375
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
376
+ normal = normal * matting + np.stack([z, z, 1 - z], axis=2) * (1 - matting)
377
+
378
+ results = [normal, left, right, bottom, top] + inner_results
379
+ results = [(x * 127.5 + 127.5).clip(0, 255).astype(np.uint8) for x in results]
380
+ return results
381
+
382
+
383
+ quick_prompts = [
384
+ 'beautiful woman',
385
+ 'handsome man',
386
+ 'beautiful woman, cinematic lighting',
387
+ 'handsome man, cinematic lighting',
388
+ 'beautiful woman, natural lighting',
389
+ 'handsome man, natural lighting',
390
+ 'beautiful woman, neo punk lighting, cyberpunk',
391
+ 'handsome man, neo punk lighting, cyberpunk',
392
+ ]
393
+ quick_prompts = [[x] for x in quick_prompts]
394
+
395
+
396
+ class BGSource(Enum):
397
+ UPLOAD = "Use Background Image"
398
+ UPLOAD_FLIP = "Use Flipped Background Image"
399
+ LEFT = "Left Light"
400
+ RIGHT = "Right Light"
401
+ TOP = "Top Light"
402
+ BOTTOM = "Bottom Light"
403
+ GREY = "Ambient"
404
+
405
+
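+ # Gradio UI: foreground/background uploads, a prompt, a background-source radio mapped to BGSource, and Relight / Compute Normal buttons that share the same input list.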
406
+ block = gr.Blocks().queue()
407
+ with block:
408
+ with gr.Row():
409
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
410
+ with gr.Row():
411
+ with gr.Column():
412
+ with gr.Row():
413
+ input_fg = gr.Image(source='upload', type="numpy", label="Foreground", height=480)
414
+ input_bg = gr.Image(source='upload', type="numpy", label="Background", height=480)
415
+ prompt = gr.Textbox(label="Prompt")
416
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
417
+ value=BGSource.UPLOAD.value,
418
+ label="Background Source", type='value')
419
+
420
+ example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
421
+ bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
422
+ relight_button = gr.Button(value="Relight")
423
+
424
+ with gr.Group():
425
+ with gr.Row():
426
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
427
+ seed = gr.Number(label="Seed", value=12345, precision=0)
428
+ with gr.Row():
429
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
430
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
431
+
432
+ with gr.Accordion("Advanced options", open=False):
433
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
434
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
435
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
436
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
437
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
438
+ n_prompt = gr.Textbox(label="Negative Prompt",
439
+ value='lowres, bad anatomy, bad hands, cropped, worst quality')
440
+ normal_button = gr.Button(value="Compute Normal (4x Slower)")
441
+ with gr.Column():
442
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
443
+ with gr.Row():
444
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
445
+ gr.Examples(
446
+ fn=lambda *args: [args[-1]],
447
+ examples=db_examples.background_conditioned_examples,
448
+ inputs=[
449
+ input_fg, input_bg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
450
+ ],
451
+ outputs=[result_gallery],
452
+ run_on_click=True, examples_per_page=1024
453
+ )
454
+ ips = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
455
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
456
+ normal_button.click(fn=process_normal, inputs=ips, outputs=[result_gallery])
457
+ example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)
458
+
459
+ def bg_gallery_selected(gal, evt: gr.SelectData):
460
+ return gal[evt.index]['name']
461
+
462
+ bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
463
+
464
+
465
+ block.launch(server_name='0.0.0.0')
briarmbg.py ADDED
@@ -0,0 +1,462 @@
1
+ # RMBG1.4 (diffusers implementation)
2
+ # Found on huggingface space of several projects
3
+ # Not sure which project is the source of this file
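+ # Architecture: a U^2-Net-style encoder/decoder built from RSU blocks; forward() returns six sigmoid side outputs (finest decoder stage first) plus the intermediate decoder features.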
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+
11
+ class REBNCONV(nn.Module):
12
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
13
+ super(REBNCONV, self).__init__()
14
+
15
+ self.conv_s1 = nn.Conv2d(
16
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
17
+ )
18
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
19
+ self.relu_s1 = nn.ReLU(inplace=True)
20
+
21
+ def forward(self, x):
22
+ hx = x
23
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
24
+
25
+ return xout
26
+
27
+
28
+ def _upsample_like(src, tar):
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+ return src
31
+
32
+
33
+ ### RSU-7 ###
34
+ class RSU7(nn.Module):
35
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
36
+ super(RSU7, self).__init__()
37
+
38
+ self.in_ch = in_ch
39
+ self.mid_ch = mid_ch
40
+ self.out_ch = out_ch
41
+
42
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
43
+
44
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
45
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
46
+
47
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
48
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
49
+
50
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
51
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
52
+
53
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
54
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
55
+
56
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
57
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
58
+
59
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
60
+
61
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
62
+
63
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
64
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
65
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
69
+
70
+ def forward(self, x):
71
+ b, c, h, w = x.shape
72
+
73
+ hx = x
74
+ hxin = self.rebnconvin(hx)
75
+
76
+ hx1 = self.rebnconv1(hxin)
77
+ hx = self.pool1(hx1)
78
+
79
+ hx2 = self.rebnconv2(hx)
80
+ hx = self.pool2(hx2)
81
+
82
+ hx3 = self.rebnconv3(hx)
83
+ hx = self.pool3(hx3)
84
+
85
+ hx4 = self.rebnconv4(hx)
86
+ hx = self.pool4(hx4)
87
+
88
+ hx5 = self.rebnconv5(hx)
89
+ hx = self.pool5(hx5)
90
+
91
+ hx6 = self.rebnconv6(hx)
92
+
93
+ hx7 = self.rebnconv7(hx6)
94
+
95
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
96
+ hx6dup = _upsample_like(hx6d, hx5)
97
+
98
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
99
+ hx5dup = _upsample_like(hx5d, hx4)
100
+
101
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
102
+ hx4dup = _upsample_like(hx4d, hx3)
103
+
104
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
105
+ hx3dup = _upsample_like(hx3d, hx2)
106
+
107
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
108
+ hx2dup = _upsample_like(hx2d, hx1)
109
+
110
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+ hx = x
146
+
147
+ hxin = self.rebnconvin(hx)
148
+
149
+ hx1 = self.rebnconv1(hxin)
150
+ hx = self.pool1(hx1)
151
+
152
+ hx2 = self.rebnconv2(hx)
153
+ hx = self.pool2(hx2)
154
+
155
+ hx3 = self.rebnconv3(hx)
156
+ hx = self.pool3(hx3)
157
+
158
+ hx4 = self.rebnconv4(hx)
159
+ hx = self.pool4(hx4)
160
+
161
+ hx5 = self.rebnconv5(hx)
162
+
163
+ hx6 = self.rebnconv6(hx5)
164
+
165
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
166
+ hx5dup = _upsample_like(hx5d, hx4)
167
+
168
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
169
+ hx4dup = _upsample_like(hx4d, hx3)
170
+
171
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
172
+ hx3dup = _upsample_like(hx3d, hx2)
173
+
174
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
175
+ hx2dup = _upsample_like(hx2d, hx1)
176
+
177
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
178
+
179
+ return hx1d + hxin
180
+
181
+
182
+ ### RSU-5 ###
183
+ class RSU5(nn.Module):
184
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
185
+ super(RSU5, self).__init__()
186
+
187
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
188
+
189
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
190
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
191
+
192
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
193
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
194
+
195
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
196
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
197
+
198
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
199
+
200
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
201
+
202
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
203
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
204
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
205
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
206
+
207
+ def forward(self, x):
208
+ hx = x
209
+
210
+ hxin = self.rebnconvin(hx)
211
+
212
+ hx1 = self.rebnconv1(hxin)
213
+ hx = self.pool1(hx1)
214
+
215
+ hx2 = self.rebnconv2(hx)
216
+ hx = self.pool2(hx2)
217
+
218
+ hx3 = self.rebnconv3(hx)
219
+ hx = self.pool3(hx3)
220
+
221
+ hx4 = self.rebnconv4(hx)
222
+
223
+ hx5 = self.rebnconv5(hx4)
224
+
225
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
226
+ hx4dup = _upsample_like(hx4d, hx3)
227
+
228
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
229
+ hx3dup = _upsample_like(hx3d, hx2)
230
+
231
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
232
+ hx2dup = _upsample_like(hx2d, hx1)
233
+
234
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
235
+
236
+ return hx1d + hxin
237
+
238
+
239
+ ### RSU-4 ###
240
+ class RSU4(nn.Module):
241
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
242
+ super(RSU4, self).__init__()
243
+
244
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
245
+
246
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
247
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
248
+
249
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
250
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
251
+
252
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
253
+
254
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
255
+
256
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
259
+
260
+ def forward(self, x):
261
+ hx = x
262
+
263
+ hxin = self.rebnconvin(hx)
264
+
265
+ hx1 = self.rebnconv1(hxin)
266
+ hx = self.pool1(hx1)
267
+
268
+ hx2 = self.rebnconv2(hx)
269
+ hx = self.pool2(hx2)
270
+
271
+ hx3 = self.rebnconv3(hx)
272
+
273
+ hx4 = self.rebnconv4(hx3)
274
+
275
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
276
+ hx3dup = _upsample_like(hx3d, hx2)
277
+
278
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
279
+ hx2dup = _upsample_like(hx2d, hx1)
280
+
281
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
282
+
283
+ return hx1d + hxin
284
+
285
+
286
+ ### RSU-4F ###
287
+ class RSU4F(nn.Module):
288
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
289
+ super(RSU4F, self).__init__()
290
+
291
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
292
+
293
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
294
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
295
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
296
+
297
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
298
+
299
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
300
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
301
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
302
+
303
+ def forward(self, x):
304
+ hx = x
305
+
306
+ hxin = self.rebnconvin(hx)
307
+
308
+ hx1 = self.rebnconv1(hxin)
309
+ hx2 = self.rebnconv2(hx1)
310
+ hx3 = self.rebnconv3(hx2)
311
+
312
+ hx4 = self.rebnconv4(hx3)
313
+
314
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
315
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
316
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
317
+
318
+ return hx1d + hxin
319
+
320
+
321
+ class myrebnconv(nn.Module):
322
+ def __init__(
323
+ self,
324
+ in_ch=3,
325
+ out_ch=1,
326
+ kernel_size=3,
327
+ stride=1,
328
+ padding=1,
329
+ dilation=1,
330
+ groups=1,
331
+ ):
332
+ super(myrebnconv, self).__init__()
333
+
334
+ self.conv = nn.Conv2d(
335
+ in_ch,
336
+ out_ch,
337
+ kernel_size=kernel_size,
338
+ stride=stride,
339
+ padding=padding,
340
+ dilation=dilation,
341
+ groups=groups,
342
+ )
343
+ self.bn = nn.BatchNorm2d(out_ch)
344
+ self.rl = nn.ReLU(inplace=True)
345
+
346
+ def forward(self, x):
347
+ return self.rl(self.bn(self.conv(x)))
348
+
349
+
350
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
351
+ def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
352
+ super(BriaRMBG, self).__init__()
353
+ in_ch = config["in_ch"]
354
+ out_ch = config["out_ch"]
355
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
356
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
357
+
358
+ self.stage1 = RSU7(64, 32, 64)
359
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
360
+
361
+ self.stage2 = RSU6(64, 32, 128)
362
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
363
+
364
+ self.stage3 = RSU5(128, 64, 256)
365
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
366
+
367
+ self.stage4 = RSU4(256, 128, 512)
368
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
369
+
370
+ self.stage5 = RSU4F(512, 256, 512)
371
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
372
+
373
+ self.stage6 = RSU4F(512, 256, 512)
374
+
375
+ # decoder
376
+ self.stage5d = RSU4F(1024, 256, 512)
377
+ self.stage4d = RSU4(1024, 128, 256)
378
+ self.stage3d = RSU5(512, 64, 128)
379
+ self.stage2d = RSU6(256, 32, 64)
380
+ self.stage1d = RSU7(128, 16, 64)
381
+
382
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
383
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
384
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
385
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
386
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
387
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
388
+
389
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
390
+
391
+ def forward(self, x):
392
+ hx = x
393
+
394
+ hxin = self.conv_in(hx)
395
+ # hx = self.pool_in(hxin)
396
+
397
+ # stage 1
398
+ hx1 = self.stage1(hxin)
399
+ hx = self.pool12(hx1)
400
+
401
+ # stage 2
402
+ hx2 = self.stage2(hx)
403
+ hx = self.pool23(hx2)
404
+
405
+ # stage 3
406
+ hx3 = self.stage3(hx)
407
+ hx = self.pool34(hx3)
408
+
409
+ # stage 4
410
+ hx4 = self.stage4(hx)
411
+ hx = self.pool45(hx4)
412
+
413
+ # stage 5
414
+ hx5 = self.stage5(hx)
415
+ hx = self.pool56(hx5)
416
+
417
+ # stage 6
418
+ hx6 = self.stage6(hx)
419
+ hx6up = _upsample_like(hx6, hx5)
420
+
421
+ # -------------------- decoder --------------------
422
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
423
+ hx5dup = _upsample_like(hx5d, hx4)
424
+
425
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
426
+ hx4dup = _upsample_like(hx4d, hx3)
427
+
428
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
429
+ hx3dup = _upsample_like(hx3d, hx2)
430
+
431
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
432
+ hx2dup = _upsample_like(hx2d, hx1)
433
+
434
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
435
+
436
+ # side output
437
+ d1 = self.side1(hx1d)
438
+ d1 = _upsample_like(d1, x)
439
+
440
+ d2 = self.side2(hx2d)
441
+ d2 = _upsample_like(d2, x)
442
+
443
+ d3 = self.side3(hx3d)
444
+ d3 = _upsample_like(d3, x)
445
+
446
+ d4 = self.side4(hx4d)
447
+ d4 = _upsample_like(d4, x)
448
+
449
+ d5 = self.side5(hx5d)
450
+ d5 = _upsample_like(d5, x)
451
+
452
+ d6 = self.side6(hx6)
453
+ d6 = _upsample_like(d6, x)
454
+
455
+ return [
456
+ F.sigmoid(d1),
457
+ F.sigmoid(d2),
458
+ F.sigmoid(d3),
459
+ F.sigmoid(d4),
460
+ F.sigmoid(d5),
461
+ F.sigmoid(d6),
462
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
db_examples.py ADDED
@@ -0,0 +1,3 @@
1
+ foreground_conditioned_examples = []
2
+ bg_samples = []
3
+ background_conditioned_examples = []
models/model_download_here ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch
3
+ torchvision
4
+ diffusers==0.31.0
5
+ accelerate==1.1.1
6
+ transformers==4.46.2
7
+ sentencepiece==0.2.0
8
+ opencv-python
9
+ safetensors
10
+ pillow
11
+ einops
12
+ peft
13
+ pyzipper
14
+ python-multipart==0.0.12