blanchon committed on
Commit
c9ba14c
·
1 Parent(s): e4ee2ca

Remove rgb2x

Browse files
rgb2x/example/Castlereagh_corridor_photo.png DELETED

Git LFS Details

  • SHA256: 8f77a445168dd92b97e214034f11291b8b3c0d98f3f12e34d591f56c39998fb4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
rgb2x/gradio_demo_rgb2x.py DELETED
@@ -1,166 +0,0 @@
1
- import spaces
2
- import os
3
- from typing import cast
4
- import gradio as gr
5
- from PIL import Image
6
- import torch
7
- import torchvision
8
- from diffusers import DDIMScheduler
9
- from load_image import load_exr_image, load_ldr_image
10
- from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
11
-
12
# NOTE(review): presumably this flag is what lets OpenCV decode .exr files;
# confirm against the OpenCV imread documentation.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the RGB->X weights in half precision, caching them next to this script,
# and move the pipeline to the GPU exactly once.
_pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
# `cast` only narrows the static type; it performs no work at runtime.
pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)

# Swap in a DDIM scheduler built from the loaded scheduler's config, with
# zero-terminal-SNR beta rescaling and trailing timestep spacing.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
28
-
29
-
30
@spaces.GPU
def generate(
    photo,
    seed: int = 0,
    inference_step: int = 50,
    num_samples: int = 1,
) -> list[tuple[Image.Image, str]]:
    """Estimate the intrinsic channels (AOVs) of an uploaded photo.

    The photo is loaded, resized so that its longer side is 1000 px (both
    sides rounded down to a multiple of 8), and run through the pipeline once
    per AOV and per sample. Every result is resized back to the input size.

    Args:
        photo: The uploaded file. Either an object exposing the file path as
            ``.name`` or the path itself.
        seed: Seed for the CUDA random generator shared by all pipeline calls.
        inference_step: Number of denoising steps per pipeline call.
        num_samples: How many samples to generate for every AOV.

    Returns:
        ``(image, caption)`` pairs, five per sample (albedo, normal,
        roughness, metallic, irradiance), in generation order.

    Raises:
        gr.Error: If the file extension is not .exr, .png, .jpg or .jpeg.

    The last three parameters have defaults so that callers passing only the
    photo (for example ``gr.Examples`` with a single input) still work.
    """
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Accept both file wrappers (with `.name`) and plain path strings.
    path = getattr(photo, "name", photo)
    suffix = os.path.splitext(path)[1].lower()
    if suffix == ".exr":
        photo = load_exr_image(path, tonemaping=True, clamp=True).to("cuda")
    elif suffix in (".png", ".jpg", ".jpeg"):
        photo = load_ldr_image(path, from_srgb=True).to("cuda")
    else:
        # Previously an unknown extension fell through and failed later with
        # an unrelated AttributeError on `photo.shape`.
        raise gr.Error(
            f"Unsupported file type {suffix!r}: expected .exr, .png, .jpg or .jpeg"
        )

    # `photo` is indexed as (channels, height, width) from here on.
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)

    # Round both sides down to a multiple of 8 (the photo is resized, not
    # cropped), but never below 8 so extreme aspect ratios stay valid.
    new_width = max(8, new_width // 8 * 8)
    new_height = max(8, new_height // 8 * 8)

    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    # Built once: maps every generated image back to the input resolution.
    restore_size = torchvision.transforms.Resize((old_height, old_width))

    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            generated_image = pipe(
                prompt=prompts[aov_name],
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]  # type: ignore

            generated_image = restore_size(generated_image)
            return_list.append((generated_image, f"Generated {aov_name} {i}"))

    return return_list
99
-
100
-
101
# Gradio UI: file upload and sampling parameters on the left, result gallery
# and a cached example on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given Image")
            # ".jpeg" is listed because `generate` accepts it as well.
            photo = gr.File(
                label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"]
            )

            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )
            # Every example row supplies all four arguments of `generate`;
            # passing only the photo made example caching call `generate`
            # with too few arguments.
            examples = gr.Examples(
                examples=[
                    [
                        "rgb2x/example/Castlereagh_corridor_photo.png",
                        0,
                        50,
                        1,
                    ]
                ],
                inputs=[photo, seed, inference_step, num_samples],
                outputs=[result_gallery],
                fn=generate,
                cache_mode="eager",
                cache_examples=True,
            )

    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rgb2x/load_image.py DELETED
@@ -1,119 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import torch
5
-
6
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
7
- import numpy as np
8
-
9
-
10
def convert_rgb_2_XYZ(rgb):
    """Convert an RGB tensor to CIE XYZ with a fixed 3x3 matrix.

    Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html

    Args:
        rgb: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3).

    Returns:
        A new tensor shaped like ``rgb`` holding X, Y, Z in the last axis.
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]

    xyz = torch.ones_like(rgb)
    xyz[:, :, 0] = 0.4124564 * red + 0.3575761 * green + 0.1804375 * blue
    xyz[:, :, 1] = 0.2126729 * red + 0.7151522 * green + 0.0721750 * blue
    xyz[:, :, 2] = 0.0193339 * red + 0.1191920 * green + 0.9503041 * blue
    return xyz
25
-
26
-
27
def convert_XYZ_2_Yxy(XYZ):
    """Convert CIE XYZ to Yxy (luminance plus chromaticity).

    Args:
        XYZ: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3).

    Returns:
        A new tensor shaped like ``XYZ`` whose last axis holds ``Y``,
        ``x = X / (X + Y + Z)`` and ``y = Y / (X + Y + Z)``.
    """
    Yxy = torch.ones_like(XYZ)
    Yxy[:, :, 0] = XYZ[:, :, 1]
    # The channel sum is clamped so black pixels do not divide by zero.
    # (Named `total` rather than `sum` to avoid shadowing the builtin.)
    total = torch.sum(XYZ, dim=2)
    inv_total = 1.0 / torch.clamp(total, min=1e-4)
    Yxy[:, :, 1] = XYZ[:, :, 0] * inv_total
    Yxy[:, :, 2] = XYZ[:, :, 1] * inv_total
    return Yxy
37
-
38
-
39
def convert_rgb_2_Yxy(rgb):
    """Convert RGB to Yxy by going through XYZ.

    Args:
        rgb: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3).

    Returns:
        The Yxy tensor produced by ``convert_XYZ_2_Yxy``.
    """
    xyz = convert_rgb_2_XYZ(rgb)
    return convert_XYZ_2_Yxy(xyz)
43
-
44
-
45
def convert_XYZ_2_rgb(XYZ):
    """Convert a CIE XYZ tensor to RGB with a fixed 3x3 matrix.

    Args:
        XYZ: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3).

    Returns:
        A new tensor shaped like ``XYZ`` holding R, G, B in the last axis.
        Values are not clamped here.
    """
    x_chan = XYZ[:, :, 0]
    y_chan = XYZ[:, :, 1]
    z_chan = XYZ[:, :, 2]

    rgb = torch.ones_like(XYZ)
    rgb[:, :, 0] = 3.2404542 * x_chan - 1.5371385 * y_chan - 0.4985314 * z_chan
    rgb[:, :, 1] = -0.9692660 * x_chan + 1.8760108 * y_chan + 0.0415560 * z_chan
    rgb[:, :, 2] = 0.0556434 * x_chan - 0.2040259 * y_chan + 1.0572252 * z_chan
    return rgb
59
-
60
-
61
def convert_Yxy_2_XYZ(Yxy):
    """Convert Yxy (luminance plus chromaticity) back to CIE XYZ.

    Args:
        Yxy: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3), with
            channels ``Y``, ``x``, ``y``.

    Returns:
        A new tensor shaped like ``Yxy`` with ``X = x / y * Y``, ``Y`` and
        ``Z = (1 - x - y) / y * Y`` in the last axis.
    """
    luminance = Yxy[:, :, 0]
    chroma_x = Yxy[:, :, 1]
    chroma_y = Yxy[:, :, 2]

    # NOTE(review): the divisor is clamped at 1e-6 for X but at 1e-4 for Z.
    # Kept exactly as in the original; confirm whether the difference is
    # intentional.
    xyz = torch.ones_like(Yxy)
    xyz[:, :, 0] = chroma_x / torch.clamp(chroma_y, min=1e-6) * luminance
    xyz[:, :, 1] = luminance
    xyz[:, :, 2] = (
        (1.0 - chroma_x - chroma_y) / torch.clamp(chroma_y, min=1e-4) * luminance
    )
    return xyz
73
-
74
-
75
def convert_Yxy_2_rgb(Yxy):
    """Convert Yxy to RGB by going through XYZ.

    Args:
        Yxy: Tensor indexed as ``[:, :, channel]``, i.e. (h, w, 3).

    Returns:
        The RGB tensor produced by ``convert_XYZ_2_rgb``.
    """
    xyz = convert_Yxy_2_XYZ(Yxy)
    return convert_XYZ_2_rgb(xyz)
79
-
80
-
81
def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
    """Load a PNG/JPEG file as a float32 RGB tensor.

    Args:
        image_path: Path of the image file.
        from_srgb: If true, raise the values to the power 2.2 (a gamma
            approximation of the sRGB-to-linear conversion).
        clamp: If true, clamp the values to [0, 1].
        normalize: If true, map the values to [-1, 1] and scale every pixel's
            channel vector to unit length.

    Returns:
        Tensor in (c, h, w) layout; values are in [0, 1] before the optional
        steps are applied.

    Raises:
        OSError: If OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None; without this check
        # the error only surfaces as an opaque cvtColor assertion.
        raise OSError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype(np.float32) / 255.0)  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if from_srgb:
        # Convert from sRGB to linear RGB
        image = image**2.2
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        # Normalize to [-1, 1], then to unit length along the channel axis
        image = image * 2.0 - 1.0
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)
96
-
97
-
98
def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
    """Load an EXR file as a float32 RGB tensor.

    Args:
        image_path: Path of the EXR file.
        tonemaping: If true, rescale the luminance so that the image's
            log-average luminance maps to 0.18, keeping chromaticity.
        clamp: If true, clamp the values to [0, 1].
        normalize: If true, scale every pixel's channel vector to unit length.

    Returns:
        Tensor in (c, h, w) layout. Non-finite input values are replaced by 0.

    Raises:
        OSError: If OpenCV cannot read the file.
    """
    # Flag -1 is cv2.IMREAD_UNCHANGED: keep the file's own depth and channels.
    raw = cv2.imread(image_path, -1)
    if raw is None:
        # cv2.imread reports failure by returning None; without this check
        # the error only surfaces as an opaque cvtColor assertion.
        raise OSError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype("float32"))  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if tonemaping:
        # Exposure adjustment
        image_Yxy = convert_rgb_2_Yxy(image)
        lum = (
            image[:, :, 0:1] * 0.2125
            + image[:, :, 1:2] * 0.7154
            + image[:, :, 2:3] * 0.0721
        )
        # Geometric mean of the luminance, computed in log space.
        lum = torch.log(torch.clamp(lum, min=1e-6))
        lum_mean = torch.exp(torch.mean(lum))
        lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
        image_Yxy[:, :, 0:1] = lp
        image = convert_Yxy_2_rgb(image_Yxy)
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rgb2x/pipeline_rgb2x.py DELETED
@@ -1,821 +0,0 @@
1
- import inspect
2
- from dataclasses import dataclass
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import PIL
7
- import torch
8
- from diffusers.configuration_utils import register_to_config
9
- from diffusers.image_processor import VaeImageProcessor
10
- from diffusers.loaders import (
11
- LoraLoaderMixin,
12
- TextualInversionLoaderMixin,
13
- )
14
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
15
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
17
- rescale_noise_cfg,
18
- )
19
- from diffusers.schedulers import KarrasDiffusionSchedulers
20
- from diffusers.utils import (
21
- CONFIG_NAME,
22
- BaseOutput,
23
- deprecate,
24
- logging,
25
- )
26
- from diffusers.utils.torch_utils import randn_tensor
27
- from transformers import CLIPTextModel, CLIPTokenizer
28
-
29
- logger = logging.get_logger(__name__)
30
-
31
-
32
class VaeImageProcrssorAOV(VaeImageProcessor):
    """
    Image processor for VAE AOV outputs.

    The constructor arguments are recorded through `register_to_config`; the
    parent constructor is called without arguments.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Registered in the config; not read by the methods defined here.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            Registered in the config; not read by the methods defined here.
        resample (`str`, *optional*, defaults to `lanczos`):
            Registered in the config; not read by the methods defined here.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Default for de-normalizing in `postprocess` when `do_denormalize` is not given.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
        do_gamma_correction: bool = True,
    ):
        """
        Convert a batch tensor into the requested output format.

        Args:
            image: Batch tensor; anything that is not a tensor is rejected.
            output_type: One of `latent`, `pt`, `np`, `pil`. Any other value
                triggers a deprecation notice and is treated as `np`.
            do_denormalize: Per-sample flags; defaults to the configured
                `do_normalize` for every sample.
            do_gamma_correction: Whether to raise the values to `1 / 2.2`.

        Returns:
            The untouched input for `latent`, a tensor for `pt`, otherwise the
            result of `pt_to_numpy` (`np`) or `numpy_to_pil` (`pil`).
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )

        supported_types = ("latent", "pt", "np", "pil")
        if output_type not in supported_types:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate(
                "Unsupported output_type",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            output_type = "np"

        # Latents are handed back as they are.
        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        samples = []
        for idx in range(image.shape[0]):
            sample = image[idx]
            if do_denormalize[idx]:
                sample = self.denormalize(sample)
            samples.append(sample)
        image = torch.stack(samples)

        # Gamma correction
        if do_gamma_correction:
            image = torch.pow(image, 1.0 / 2.2)

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        if output_type == "pil":
            return self.numpy_to_pil(image)

    def preprocess_normal(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Add a leading batch axis to a single image tensor.

        `height` and `width` are accepted but not used.
        NOTE(review): despite the annotation, `torch.stack` only accepts
        tensors, so PIL or NumPy inputs look unsupported here; confirm.
        """
        batched = torch.stack([image], axis=0)
        return batched
119
-
120
-
121
@dataclass
class StableDiffusionAOVPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion AOV pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            The generated images, either as a list of PIL images or as a NumPy array. This is the only field;
            unlike the stock Stable Diffusion output there is no `nsfw_content_detected` entry.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
136
-
137
-
138
- class StableDiffusionAOVMatEstPipeline(
139
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
140
- ):
141
- r"""
142
- Pipeline for AOVs.
143
-
144
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
145
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
146
-
147
- The pipeline also inherits the following loading methods:
148
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
149
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
150
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
151
-
152
- Args:
153
- vae ([`AutoencoderKL`]):
154
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
155
- text_encoder ([`~transformers.CLIPTextModel`]):
156
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
157
- tokenizer ([`~transformers.CLIPTokenizer`]):
158
- A `CLIPTokenizer` to tokenize text.
159
- unet ([`UNet2DConditionModel`]):
160
- A `UNet2DConditionModel` to denoise the encoded image latents.
161
- scheduler ([`SchedulerMixin`]):
162
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
163
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
164
- """
165
-
166
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
):
    """Register the pipeline components and derive the VAE helpers.

    The VAE scale factor is ``2 ** (number of VAE block_out_channels - 1)``
    and is also handed to the AOV image processor.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    vae_levels = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_levels - 1)
    self.image_processor = VaeImageProcrssorAOV(
        vae_scale_factor=self.vae_scale_factor
    )
    self.register_to_config()
188
-
189
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is true.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.

    Returns:
        The prompt embeddings. With classifier free guidance the result is the concatenation
        `[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]` along the batch axis.
    """
    # Batch size comes from the prompt when given, otherwise from the embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenized a second time without truncation, only to report what was cut off.
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[
            -1
        ] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # The attention mask is only passed when the text encoder config asks for it.
        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # The first element of the encoder output is used as the embedding.
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional tokens to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(
            dtype=self.text_encoder.dtype, device=device
        )

        negative_prompt_embeds = negative_prompt_embeds.repeat(
            1, num_images_per_prompt, 1
        )
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
        prompt_embeds = torch.cat(
            [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
        )

    return prompt_embeds
353
-
354
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for ``self.scheduler.step``.

    Not all schedulers share the same ``step`` signature, so ``eta`` and
    ``generator`` are only included when ``step`` declares a parameter of
    that name.

    Args:
        generator: Random generator to forward to the scheduler.
        eta: The eta (η) of the DDIM paper (https://arxiv.org/abs/2010.02502);
            it should be between [0, 1].

    Returns:
        A dict containing ``"eta"`` and/or ``"generator"``, possibly empty.
    """
    # Inspect the signature once and reuse it for both lookups.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
374
-
375
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related arguments of a pipeline call.

    Raises:
        ValueError: If ``callback_steps`` is not a positive integer, if
            ``prompt`` and ``prompt_embeds`` are both given or both missing,
            if ``prompt`` is neither ``str`` nor ``list``, if both negative
            inputs are given, or if the two embedding tensors differ in shape.
    """
    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if (
        has_embeds
        and has_negative_embeds
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
421
-
422
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create (or reuse) the initial latents and scale them for the scheduler.

    Args:
        batch_size: Number of latent samples.
        num_channels_latents: Channel count of the latents.
        height: Image height; divided by ``self.vae_scale_factor``.
        width: Image width; divided by ``self.vae_scale_factor``.
        dtype: dtype used when new latents are sampled.
        device: Target device.
        generator: A generator, or a list with one generator per sample.
        latents: Optional pre-made latents; moved to ``device`` if given.

    Returns:
        The latents multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: If a generator list does not have ``batch_size`` entries.
    """
    scale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // scale, width // scale)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(
            latent_shape, generator=generator, device=device, dtype=dtype
        )

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
455
-
456
def prepare_image_latents(
    self,
    image,
    batch_size,
    num_images_per_prompt,
    dtype,
    device,
    do_classifier_free_guidance,
    generator=None,
):
    """Encode the conditioning image into latents and match the batch size.

    Args:
        image: Conditioning image batch. An input whose second axis has size 4
            is treated as already-encoded latents and is not encoded again.
        batch_size: Number of prompts; multiplied by `num_images_per_prompt`
            to get the effective batch size.
        num_images_per_prompt: Images generated per prompt.
        dtype: dtype the image is cast to.
        device: Device the image is moved to.
        do_classifier_free_guidance: If true, the result is the concatenation
            `[image_latents, image_latents, zeros]` along the batch axis.
        generator: A generator, or a list with one generator per sample.

    Returns:
        The image latents.

    Raises:
        ValueError: For an unsupported `image` type, a generator list of the
            wrong length, or an image batch that cannot be tiled to the
            effective batch size.
    """
    # NOTE(review): PIL images and lists pass this check, but the `.to(...)`
    # call below only exists on tensors - confirm whether those types are
    # really meant to be accepted.
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    # Four channels: the input is taken to be latents already.
    if image.shape[1] == 4:
        image_latents = image
    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # The mode of the latent distribution is used, so the generator is
        # never sampled from here; a generator list only switches to
        # per-sample encoding.
        if isinstance(generator, list):
            image_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.mode()
                for i in range(batch_size)
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = self.vae.encode(image).latent_dist.mode()

    if (
        batch_size > image_latents.shape[0]
        and batch_size % image_latents.shape[0] == 0
    ):
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate(
            "len(prompt) != len(image)",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat(
            [image_latents] * additional_image_per_prompt, dim=0
        )
    elif (
        batch_size > image_latents.shape[0]
        and batch_size % image_latents.shape[0] != 0
    ):
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        # Single-element concatenation: the values are unchanged.
        image_latents = torch.cat([image_latents], dim=0)

    if do_classifier_free_guidance:
        # The third chunk is an all-zero latent used as the image-unconditional branch.
        uncond_image_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat(
            [image_latents, image_latents, uncond_image_latents], dim=0
        )

    return image_latents
531
-
532
- @torch.no_grad()
533
- def __call__(
534
- self,
535
- prompt: Union[str, List[str]] = None,
536
- photo: Union[
537
- torch.FloatTensor,
538
- PIL.Image.Image,
539
- np.ndarray,
540
- List[torch.FloatTensor],
541
- List[PIL.Image.Image],
542
- List[np.ndarray],
543
- ] = None,
544
- height: Optional[int] = None,
545
- width: Optional[int] = None,
546
- num_inference_steps: int = 100,
547
- required_aovs: List[str] = ["albedo"],
548
- negative_prompt: Optional[Union[str, List[str]]] = None,
549
- num_images_per_prompt: Optional[int] = 1,
550
- use_default_scaling_factor: Optional[bool] = False,
551
- guidance_scale: float = 0.0,
552
- image_guidance_scale: float = 0.0,
553
- guidance_rescale: float = 0.0,
554
- eta: float = 0.0,
555
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
556
- latents: Optional[torch.FloatTensor] = None,
557
- prompt_embeds: Optional[torch.FloatTensor] = None,
558
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
559
- output_type: Optional[str] = "pil",
560
- return_dict: bool = True,
561
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
562
- callback_steps: int = 1,
563
- ):
564
- r"""
565
- The call function to the pipeline for generation.
566
-
567
- Args:
568
- prompt (`str` or `List[str]`, *optional*):
569
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
570
- image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
571
- `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
572
- image latents as `image`, but if passing latents directly it is not encoded again.
573
- num_inference_steps (`int`, *optional*, defaults to 100):
574
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
575
- expense of slower inference.
576
- guidance_scale (`float`, *optional*, defaults to 7.5):
577
- A higher guidance scale value encourages the model to generate images closely linked to the text
578
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
579
- image_guidance_scale (`float`, *optional*, defaults to 1.5):
580
- Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
581
- `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
582
- linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
583
- value of at least `1`.
584
- negative_prompt (`str` or `List[str]`, *optional*):
585
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
586
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
587
- num_images_per_prompt (`int`, *optional*, defaults to 1):
588
- The number of images to generate per prompt.
589
- eta (`float`, *optional*, defaults to 0.0):
590
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
591
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
592
- generator (`torch.Generator`, *optional*):
593
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
594
- generation deterministic.
595
- latents (`torch.FloatTensor`, *optional*):
596
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
597
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
598
- tensor is generated by sampling using the supplied random `generator`.
599
- prompt_embeds (`torch.FloatTensor`, *optional*):
600
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
601
- provided, text embeddings are generated from the `prompt` input argument.
602
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
603
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
604
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
605
- output_type (`str`, *optional*, defaults to `"pil"`):
606
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
607
- return_dict (`bool`, *optional*, defaults to `True`):
608
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
609
- plain tuple.
610
- callback (`Callable`, *optional*):
611
- A function that is called every `callback_steps` steps during inference. The function is called with the
612
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
613
- callback_steps (`int`, *optional*, defaults to 1):
614
- The frequency at which the `callback` function is called. If not specified, the callback is called at
615
- every step.
616
-
617
- Examples:
618
-
619
- ```py
620
- >>> import PIL
621
- >>> import requests
622
- >>> import torch
623
- >>> from io import BytesIO
624
-
625
- >>> from diffusers import StableDiffusionInstructPix2PixPipeline
626
-
627
-
628
- >>> def download_image(url):
629
- ... response = requests.get(url)
630
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
631
-
632
-
633
- >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
634
-
635
- >>> image = download_image(img_url).resize((512, 512))
636
-
637
- >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
638
- ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
639
- ... )
640
- >>> pipe = pipe.to("cuda")
641
-
642
- >>> prompt = "make the mountains snowy"
643
- >>> image = pipe(prompt=prompt, image=image).images[0]
644
- ```
645
-
646
- Returns:
647
- `StableDiffusionAOVPipelineOutput`:
648
- The code below always returns a `StableDiffusionAOVPipelineOutput` whose `images` field is a
649
- one-element list holding the post-processed output decoded for `required_aovs[0]`. The
650
- `return_dict` argument is not consulted by this implementation, and no "not-safe-for-work"
651
- (nsfw) flags are produced.
652
- """
653
- # 0. Check inputs
654
- self.check_inputs(
655
- prompt,
656
- callback_steps,
657
- negative_prompt,
658
- prompt_embeds,
659
- negative_prompt_embeds,
660
- )
661
-
662
- # 1. Define call parameters
663
- if prompt is not None and isinstance(prompt, str):
664
- batch_size = 1
665
- elif prompt is not None and isinstance(prompt, list):
666
- batch_size = len(prompt)
667
- else:
668
- batch_size = prompt_embeds.shape[0]
669
-
670
- device = self._execution_device
671
- do_classifier_free_guidance = (
672
- guidance_scale > 1.0 and image_guidance_scale >= 1.0
673
- )
674
- # check if scheduler is in sigmas space
675
- scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
676
-
677
- # 2. Encode input prompt
678
- prompt_embeds = self._encode_prompt(
679
- prompt,
680
- device,
681
- num_images_per_prompt,
682
- do_classifier_free_guidance,
683
- negative_prompt,
684
- prompt_embeds=prompt_embeds,
685
- negative_prompt_embeds=negative_prompt_embeds,
686
- )
687
-
688
- # 3. Preprocess image
689
- # Normalize image to [-1,1]
690
- preprocessed_photo = self.image_processor.preprocess(photo)
691
-
692
- # 4. set timesteps
693
- self.scheduler.set_timesteps(num_inference_steps, device=device)
694
- timesteps = self.scheduler.timesteps
695
-
696
- # 5. Prepare Image latents
697
- image_latents = self.prepare_image_latents(
698
- preprocessed_photo,
699
- batch_size,
700
- num_images_per_prompt,
701
- prompt_embeds.dtype,
702
- device,
703
- do_classifier_free_guidance,
704
- generator,
705
- )
706
- image_latents = image_latents * self.vae.config.scaling_factor
707
-
708
- height, width = image_latents.shape[-2:]
709
- height = height * self.vae_scale_factor
710
- width = width * self.vae_scale_factor
711
-
712
- # 6. Prepare latent variables
713
- num_channels_latents = self.unet.config.out_channels
714
- latents = self.prepare_latents(
715
- batch_size * num_images_per_prompt,
716
- num_channels_latents,
717
- height,
718
- width,
719
- prompt_embeds.dtype,
720
- device,
721
- generator,
722
- latents,
723
- )
724
-
725
- # 7. Check that shapes of latents and image match the UNet channels
726
- num_channels_image = image_latents.shape[1]
727
- if num_channels_latents + num_channels_image != self.unet.config.in_channels:
728
- raise ValueError(
729
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
730
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
731
- f" `num_channels_image`: {num_channels_image} "
732
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
733
- " `pipeline.unet` or your `image` input."
734
- )
735
-
736
- # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
737
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
738
-
739
- # 9. Denoising loop
740
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
741
- with self.progress_bar(total=num_inference_steps) as progress_bar:
742
- for i, t in enumerate(timesteps):
743
- # Expand the latents if we are doing classifier free guidance.
744
- # The latents are expanded 3 times because for pix2pix the guidance
745
- # is applied for both the text and the input image.
746
- latent_model_input = (
747
- torch.cat([latents] * 3) if do_classifier_free_guidance else latents
748
- )
749
-
750
- # concat latents, image_latents in the channel dimension
751
- scaled_latent_model_input = self.scheduler.scale_model_input(
752
- latent_model_input, t
753
- )
754
- scaled_latent_model_input = torch.cat(
755
- [scaled_latent_model_input, image_latents], dim=1
756
- )
757
-
758
- # predict the noise residual
759
- noise_pred = self.unet(
760
- scaled_latent_model_input,
761
- t,
762
- encoder_hidden_states=prompt_embeds,
763
- return_dict=False,
764
- )[0]
765
-
766
- # perform guidance
767
- if do_classifier_free_guidance:
768
- (
769
- noise_pred_text,
770
- noise_pred_image,
771
- noise_pred_uncond,
772
- ) = noise_pred.chunk(3)
773
- noise_pred = (
774
- noise_pred_uncond
775
- + guidance_scale * (noise_pred_text - noise_pred_image)
776
- + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
777
- )
778
-
779
- if do_classifier_free_guidance and guidance_rescale > 0.0:
780
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
781
- noise_pred = rescale_noise_cfg(
782
- noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
783
- )
784
-
785
- # compute the previous noisy sample x_t -> x_t-1
786
- latents = self.scheduler.step(
787
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
788
- )[0]
789
-
790
- # call the callback, if provided
791
- if i == len(timesteps) - 1 or (
792
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
793
- ):
794
- progress_bar.update()
795
- if callback is not None and i % callback_steps == 0:
796
- callback(i, t, latents)
797
-
798
- aov_latents = latents / self.vae.config.scaling_factor
799
- aov = self.vae.decode(aov_latents, return_dict=False)[0]
800
- do_denormalize = [True] * aov.shape[0]
801
- aov_name = required_aovs[0]
802
- if aov_name == "albedo" or aov_name == "irradiance":
803
- do_gamma_correction = True
804
- else:
805
- do_gamma_correction = False
806
-
807
- if aov_name == "roughness" or aov_name == "metallic":
808
- aov = aov[:, 0:1].repeat(1, 3, 1, 1)
809
-
810
- aov = self.image_processor.postprocess(
811
- aov,
812
- output_type=output_type,
813
- do_denormalize=do_denormalize,
814
- do_gamma_correction=do_gamma_correction,
815
- )
816
- aovs = [aov]
817
-
818
- # Offload last model to CPU
819
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
820
- self.final_offload_hook.offload()
821
- return StableDiffusionAOVPipelineOutput(images=aovs)