# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from tqdm import tqdm

from fourm.utils import to_2tuple
def rescale_noise_cfg(noise_cfg, noise_pred_conditional, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: The classifier-free-guided noise prediction to be rescaled.
        noise_pred_conditional: The conditional (text-guided) noise prediction whose
            per-sample std serves as the rescaling target.
        guidance_rescale: Interpolation factor in [0, 1] between the rescaled
            prediction (1.0) and the unmodified CFG prediction (0.0).

    Returns:
        The (partially) rescaled noise prediction, same shape as `noise_cfg`.
    """
    # Per-sample standard deviations, reducing over all non-batch dimensions.
    reduce_dims = list(range(1, noise_pred_conditional.ndim))
    std_conditional = noise_pred_conditional.std(dim=reduce_dims, keepdim=True)
    std_guided = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # Rescale the guided prediction so its std matches the conditional one
    # (fixes overexposure), then blend with the original to avoid "plain
    # looking" images.
    rescaled = noise_cfg * (std_conditional / std_guided)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
class PipelineCond(DiffusionPipeline):
    """Pipeline for conditional image generation.

    This model inherits from `DiffusionPipeline`. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        model: The conditional diffusion model.
        scheduler: A diffusion scheduler, e.g. see scheduling_ddpm.py
    """

    def __init__(self, model: torch.nn.Module, scheduler: SchedulerMixin):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self,
                 cond: torch.Tensor,
                 generator: Optional[torch.Generator] = None,
                 timesteps: Optional[int] = None,
                 guidance_scale: float = 0.0,
                 guidance_rescale: float = 0.0,
                 image_size: Optional[Union[Tuple[int, int], int]] = None,
                 verbose: bool = True,
                 scheduler_timesteps_mode: str = 'trailing',
                 orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
                 **kwargs) -> torch.Tensor:
        """Run the full denoising loop and return a generated image batch.

        Args:
            cond: The conditional input to the model (4D, batch first).
            generator: A torch.Generator to make generation deterministic.
            timesteps: The number of denoising steps. More denoising steps usually lead to a higher
                quality image at the expense of slower inference. Defaults to the number of training
                timesteps if not given.
            guidance_scale: The scale of the classifier-free guidance. May also be a callable
                mapping the normalized timestep t/T to a scale. If set to 0.0, no guidance is used.
            guidance_rescale: Rescaling factor to fix the variance when using guidance scaling.
            image_size: The size of the image to generate. If not given, the default training size
                of the model is used.
            verbose: Whether to show a progress bar.
            scheduler_timesteps_mode: The mode to use for DDIMScheduler. One of `trailing`, `linspace`,
                `leading`. See https://arxiv.org/abs/2305.08891 for more details.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL https://arxiv.org/abs/2307.01952 for more details.

        Returns:
            The generated image tensor of shape (batch, in_channels, H, W).
        """
        if timesteps is None:
            timesteps = self.scheduler.config.num_train_timesteps
        # Unpacking also asserts that `cond` is 4D.
        batch_size, _, _, _ = cond.shape

        # Start from pure Gaussian noise at the requested resolution.
        if image_size is None:
            image_size = self.model.sample_size
        height, width = to_2tuple(image_size)
        sample = torch.randn(
            (batch_size, self.model.in_channels, height, width),
            generator=generator,
        )
        sample = sample.to(self.model.device)

        # Classifier-free guidance is active for scale > 1.0 or a scale schedule (callable).
        use_cfg = callable(guidance_scale) or guidance_scale > 1.0

        # Configure the scheduler's discretization for this run.
        self.scheduler.set_timesteps(timesteps, mode=scheduler_timesteps_mode)

        for t in tqdm(self.scheduler.timesteps, disable=not verbose):
            # 1. Predict noise model_output
            noise_pred = self.model(sample, t, cond, orig_res=orig_res, **kwargs)

            if use_cfg:
                # TODO: is there a better way to get unconditional output?
                noise_pred_uncond = self.model(sample, t, cond, unconditional=True, **kwargs)
                if callable(guidance_scale):
                    scale = guidance_scale(t / self.scheduler.config.num_train_timesteps)
                else:
                    scale = guidance_scale
                guided = noise_pred_uncond + scale * (noise_pred - noise_pred_uncond)
                if guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(guided, noise_pred, guidance_rescale=guidance_rescale)
                else:
                    noise_pred = guided

            # 2. Compute previous image: x_t -> t_t-1. The scheduler update is
            # run in full precision regardless of any surrounding autocast.
            with torch.cuda.amp.autocast(enabled=False):
                sample = self.scheduler.step(noise_pred.float(), t, sample, generator=generator).prev_sample

        return sample