clementchadebec committed
Commit a88bb44 · verified · 1 Parent(s): 4949f93

Upload 3 files

Files changed (3):
  1. app.py +216 -0
  2. requirements.txt +14 -0
  3. utils.py +177 -0
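
The three files make up the Gradio demo: utils.py builds the LBM relighting model and the image helpers, app.py wires them into the web UI, and requirements.txt pins the Python dependencies. To inspect or run them locally, a minimal sketch using huggingface_hub follows; the Space id used here is an assumption (adjust it to the actual Space), only snapshot_download itself is standard API.

# Hedged sketch: download the Space files locally for inspection.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="jasperai/LBM_relighting",  # assumed Space id; replace with the real one
    repo_type="space",
)
print(local_dir)  # should contain app.py, requirements.txt and utils.py after this commit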
app.py ADDED
@@ -0,0 +1,216 @@
import glob
from copy import deepcopy

import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation

from utils import extract_object, get_model_from_config, resize_and_center_crop

# candidate (width, height) buckets, keyed by their width/height aspect ratio
ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

# download the config and model
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.safetensors")
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.yaml")

with open(CONFIG_PATH, "r") as f:
    config = yaml.safe_load(f)

model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).cuda()
image_size = (1024, 1024)


@spaces.GPU
def evaluate(
    fg_image: PIL.Image.Image,
    bg_image: PIL.Image.Image,
    num_sampling_steps: int = 1,
):
    # PIL's Image.size is (width, height)
    ori_w_bg, ori_h_bg = fg_image.size
    ar_bg = ori_w_bg / ori_h_bg
    closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
    dimensions_bg = ASPECT_RATIOS[closest_ar_bg]

    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, dimensions_bg[0], dimensions_bg[1])
    fg_mask = resize_and_center_crop(fg_mask, dimensions_bg[0], dimensions_bg[1])
    bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1])

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {
        "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
    }

    z_source = model.vae.encode(batch[model.source_key])

    output_image = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # paste the output image on the background image
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # resize both slider images back to the original input resolution
    # (Image.resize returns a new image, so the result must be assigned)
    output_image = output_image.resize((ori_w_bg, ori_h_bg))
    img_pasted = img_pasted.resize((ori_w_bg, ori_h_bg))
    print(output_image.size, img_pasted.size)

    return (np.array(img_pasted), np.array(output_image))


with gr.Blocks(title="LBM Object Relighting") as demo:
    gr.Markdown(
        """
# Object Relighting with Latent Bridge Matching
This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2403.03025) *by Jasper Research*. We are internally exploring the possibility of releasing the model. If you enjoy the space, please also promote *open-source* by giving a ⭐ to the <a href='https://github.com/gojasper/LBM' target='_blank'>Github Repo</a>.
"""
    )
    gr.Markdown(
        "💡 *Hint:* To better appreciate the low latency of our method, run the demo locally!"
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(
                    type="pil",
                    label="Input Image",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )
                bg_image = gr.Image(
                    type="pil",
                    label="Target Background",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )

            with gr.Row():
                submit_button = gr.Button("Relight", variant="primary")
            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Inference Steps",
                )

            bg_gallery = gr.Gallery(
                # height=450,
                object_fit="contain",
                label="Background List",
                value=[path for path in glob.glob("examples/backgrounds/*.jpg")],
                columns=5,
                allow_preview=False,
            )

        with gr.Column():
            output_slider = ImageSlider(label="Composite vs LBM", type="numpy")
            output_slider.upload(
                fn=evaluate,
                inputs=[fg_image, bg_image, num_inference_steps],
                outputs=[output_slider],
            )

    submit_button.click(
        evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
        show_progress=False,
        show_api=False,
    )

    with gr.Row():
        gr.Examples(
            fn=evaluate,
            examples=[
                ["examples/foregrounds/2.jpg", "examples/backgrounds/14.jpg", 1],
                ["examples/foregrounds/10.jpg", "examples/backgrounds/4.jpg", 1],
                ["examples/foregrounds/11.jpg", "examples/backgrounds/24.jpg", 1],
                ["examples/foregrounds/19.jpg", "examples/backgrounds/3.jpg", 1],
                ["examples/foregrounds/4.jpg", "examples/backgrounds/6.jpg", 1],
                ["examples/foregrounds/14.jpg", "examples/backgrounds/22.jpg", 1],
                ["examples/foregrounds/12.jpg", "examples/backgrounds/1.jpg", 1],
            ],
            inputs=[fg_image, bg_image, num_inference_steps],
            outputs=[output_slider],
            run_on_click=True,
        )

    gr.Markdown("**Disclaimer:**")
    gr.Markdown(
        "This demo is for research purposes only. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
    )

    def bg_gallery_selected(gal, evt: gr.SelectData):
        print(gal, evt.index)
        return gal[evt.index][0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":
    demo.queue().launch(share=True, show_api=False)
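
For local testing without the web UI, the sketch below calls evaluate() from app.py directly on one of the example pairs shipped with the Space. This is a hedged sketch, not part of the upload: it assumes a CUDA GPU, that the lbm package is importable (utils.py needs it), and that the spaces.GPU decorator behaves as a no-op outside of a ZeroGPU Space.

# Hedged local-usage sketch: one relighting pass via evaluate().
from PIL import Image

from app import evaluate  # importing app.py downloads and loads the model onto the GPU

fg = Image.open("examples/foregrounds/2.jpg").convert("RGB")
bg = Image.open("examples/backgrounds/14.jpg").convert("RGB")

composite, relit = evaluate(fg, bg, num_sampling_steps=1)  # two numpy arrays
Image.fromarray(composite).save("composite.png")
Image.fromarray(relit).save("relit.png")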
requirements.txt CHANGED
@@ -0,0 +1,14 @@
accelerate>=1.2.1
diffusers>=0.32.2
torch==2.5.1
torchvision>=0.20.0
torchaudio>=2.5.0
gradio==4.44.0
gradio-imageslider
lpips==0.1.4
pydantic>=2.6.1
transformers==4.42.3
xformers
einops==0.7.0
timm==0.9.16
kornia==0.8.0
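
Note that the lbm package imported by utils.py is not pinned here, so it presumably has to be installed separately (e.g. from the LBM GitHub repo). As a quick environment check, a hedged sketch for comparing the pins above against the installed versions; it assumes the packaging library is available:

# Hedged sketch: compare installed package versions against requirements.txt.
from importlib.metadata import PackageNotFoundError, version
from packaging.requirements import Requirement

with open("requirements.txt") as f:
    for line in f:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        req = Requirement(line)
        try:
            installed = version(req.name)
        except PackageNotFoundError:
            print(f"missing   {req.name}")
            continue
        ok = installed in req.specifier if req.specifier else True
        status = "ok" if ok else "mismatch"
        print(f"{status:10}{req.name} {installed} (wants '{req.specifier or 'any'}')")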
utils.py ADDED
@@ -0,0 +1,177 @@
import os
from typing import List

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from torchvision import transforms

from lbm.models.embedders import (
    ConditionerWrapper,
    LatentsConcatEmbedder,
    LatentsConcatEmbedderConfig,
)
from lbm.models.lbm import LBMConfig, LBMModel
from lbm.models.unets import DiffusersUNet2DCondWrapper
from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig


def get_model_from_config(
    backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vae_num_channels: int = 4,
    unet_input_channels: int = 4,
    timestep_sampling: str = "log_normal",
    selected_timesteps: List[float] = None,
    prob: List[float] = None,
    conditioning_images_keys: List[str] = [],
    conditioning_masks_keys: List[str] = ["mask"],
    source_key: str = "source_image",
    target_key: str = "source_image_paste",
    bridge_noise_sigma: float = 0.0,
):

    conditioners = []

    denoiser = DiffusersUNet2DCondWrapper(
        in_channels=unet_input_channels,  # Add downsampled_image
        out_channels=vae_num_channels,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        dropout=0.0,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=[320, 640, 1280],
        transformer_layers_per_block=[1, 2, 10],
        reverse_transformer_layers_per_block=None,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        attention_head_dim=[5, 10, 20],
        num_attention_heads=None,
        dual_cross_attention=False,
        use_linear_projection=True,
        class_embed_type=None,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        num_class_embeds=None,
        upcast_attention=None,
        resnet_time_scale_shift="default",
        resnet_skip_time_act=False,
        resnet_out_scale_factor=1.0,
        time_embedding_type="positional",
        time_embedding_dim=None,
        time_embedding_act_fn=None,
        timestep_post_act=None,
        time_cond_proj_dim=None,
        conv_in_kernel=3,
        conv_out_kernel=3,
        projection_class_embeddings_input_dim=None,
        attention_type="default",
        class_embeddings_concat=False,
        mid_block_only_cross_attention=None,
        cross_attention_norm=None,
        addition_embed_type_num_heads=64,
    ).to(torch.bfloat16)

    if conditioning_images_keys != [] or conditioning_masks_keys != []:

        latents_concat_embedder_config = LatentsConcatEmbedderConfig(
            image_keys=conditioning_images_keys,
            mask_keys=conditioning_masks_keys,
        )
        latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
        latent_concat_embedder.freeze()
        conditioners.append(latent_concat_embedder)

    # Wrap conditioners and set to device
    conditioner = ConditionerWrapper(
        conditioners=conditioners,
    )

    ## VAE ##
    # Get VAE model
    vae_config = AutoencoderKLDiffusersConfig(
        version=backbone_signature,
        subfolder="vae",
        tiling_size=(128, 128),
    )
    vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
    vae.freeze()
    vae.to(torch.bfloat16)

    ## Diffusion Model ##
    # Get diffusion model
    config = LBMConfig(
        source_key=source_key,
        target_key=target_key,
        timestep_sampling=timestep_sampling,
        selected_timesteps=selected_timesteps,
        prob=prob,
        bridge_noise_sigma=bridge_noise_sigma,
    )

    sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        backbone_signature,
        subfolder="scheduler",
    )

    model = LBMModel(
        config,
        denoiser=denoiser,
        sampling_noise_scheduler=sampling_noise_scheduler,
        vae=vae,
        conditioner=conditioner,
    ).to(torch.bfloat16)

    return model


def extract_object(birefnet, img):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    image = img
    input_images = transform_image(image).unsqueeze(0).cuda()

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
    return image, mask


def resize_and_center_crop(image, target_width, target_height):
    original_width, original_height = image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return cropped_image
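
As a quick sanity check of the cropping helper, a hedged sketch exercising resize_and_center_crop on a synthetic image (importing utils requires the lbm package because of the module-level imports above):

# Hedged sketch: resize_and_center_crop scales by max(target/original) per axis,
# then center-crops, so the output always has exactly the requested size.
from PIL import Image

from utils import resize_and_center_crop

img = Image.new("RGB", (400, 300), (200, 50, 50))
out = resize_and_center_crop(img, 256, 256)
print(out.size)  # (256, 256): the 400x300 input is scaled to cover 256x256, then center-cropped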