import argparse
import hashlib
import json
import os
import re
import sys
import time
from typing import TYPE_CHECKING, Union, List

from diffusers import (
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler
)
import torch
from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel

SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
# NOTE: the name is misspelled ("SCHEDLER"); it is kept as-is because other
# modules may import it under this name.
SCHEDLER_SCHEDULE = "scaled_linear"

UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
TEXT_ENCODER_2_PROJECTION_DIM = 1280
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816


def get_torch_dtype(dtype_str):
    """Translate a dtype alias string into a ``torch.dtype``.

    Args:
        dtype_str: A ``torch.dtype`` (returned unchanged) or one of the
            aliases handled below ("fp32", "fp16", "bf16", "8bit", ...).

    Returns:
        The matching ``torch.dtype``. Unrecognised values are returned
        unchanged, so callers may get back whatever they passed in.
    """
    # if it is a torch dtype, return it
    if isinstance(dtype_str, torch.dtype):
        return dtype_str
    # Tuples (not sets) are used so unhashable inputs still fall through
    # to the final return instead of raising.
    if dtype_str in ("float", "fp32", "single", "float32"):
        return torch.float
    if dtype_str in ("fp16", "half", "float16"):
        return torch.float16
    if dtype_str in ("bf16", "bfloat16"):
        return torch.bfloat16
    if dtype_str in ("8bit", "e4m3fn", "float8"):
        # Looked up lazily so torch builds without float8 only fail when
        # this alias is actually requested.
        return torch.float8_e4m3fn
    return dtype_str


def replace_filewords_prompt(prompt, args: argparse.Namespace):
    """Apply the optional ``name_replace`` / ``prepend`` / ``append`` args to a prompt.

    Each attribute is optional on ``args``; missing or ``None`` attributes
    are skipped.

    Args:
        prompt: The prompt string to transform.
        args: Namespace that may carry ``name_replace``, ``prepend`` and
            ``append`` string attributes.

    Returns:
        The transformed prompt.
    """
    name_replace = getattr(args, "name_replace", None)
    if name_replace is not None:
        # replace [name] to args.name_replace
        prompt = prompt.replace("[name]", name_replace)

    prepend = getattr(args, "prepend", None)
    if prepend is not None:
        # prepend to every item in prompt file
        prompt = prepend + ' ' + prompt

    append = getattr(args, "append", None)
    if append is not None:
        # append to every item in prompt file
        prompt = prompt + ' ' + append
    return prompt


def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace):
    """Replace ``[name]`` in every caption of ``dataset_group.image_data``.

    Nothing happens when ``args`` has no ``name_replace`` attribute or it
    is ``None``.

    Args:
        dataset_group: Object exposing ``image_data``, a keyed collection
            whose values have a string ``caption`` attribute.
        args: Namespace that may carry ``name_replace``.

    Returns:
        The same ``dataset_group`` object (captions are mutated in place).

    Raises:
        ValueError: If a replacement is requested but ``image_data`` is empty.
    """
    name_replace = getattr(args, "name_replace", None)
    if name_replace is not None:
        image_data = dataset_group.image_data
        if len(image_data) == 0:
            raise ValueError("dataset_group.image_data is empty")
        for key in image_data:
            image_data[key].caption = image_data[key].caption.replace("[name]", name_replace)
    return dataset_group


def get_seeds_from_latents(latents):
    """Derive one deterministic 32-bit seed per batch item from its latent.

    Args:
        latents: Tensor indexed as (batch, channel, height, width).

    Returns:
        A list of ints in ``[0, 2**32)``, one per batch item.
    """
    # latents shape = (batch_size, 4, height, width)
    # for speed we only use 8x8 slice of the first channel
    seeds = []
    for i in range(latents.shape[0]):
        # first channel only, scaled by 255, top-left 8x8 corner
        patch = (latents[i, 0, :, :] * 255.0)[:8, :8]
        # clip to 0-255 and quantise to 8 bit so the bytes are stable
        patch = torch.clamp(patch, 0, 255).to(torch.uint8)
        digest = hashlib.sha256(patch.cpu().numpy().tobytes()).hexdigest()
        # fold the 256-bit digest into a 32-bit seed
        seeds.append(int(digest, 16) % (2 ** 32))
    return seeds


def get_noise_from_latents(latents):
    """Generate noise that is a deterministic function of each latent.

    Side effect: reseeds the global torch RNG (and the CUDA RNG) once per
    batch item via ``torch.manual_seed`` / ``torch.cuda.manual_seed``.

    Args:
        latents: Batched latent tensor (see ``get_seeds_from_latents``).

    Returns:
        A stacked noise tensor with one ``randn_like(latents[0])`` sample
        per batch item.
    """
    noise = []
    for seed in get_seeds_from_latents(latents):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        noise.append(torch.randn_like(latents[0]))
    return torch.stack(noise)


# mix 0 is completely noise mean, mix 1 is completely target mean
def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
    """Shift the mean of ``noise`` towards the mean of ``target``.

    Args:
        noise: Noise tensor.
        target: Tensor whose mean is blended in.
        mix: Blend factor; 0 keeps the noise mean, 1 uses the target mean.
        dim: Dimension(s) the means are reduced over. Defaults to
            ``(1, 2, 3)``, i.e. one mean per batch item.

    Returns:
        ``noise`` with its mean over ``dim`` replaced by the blended mean.
    """
    # ``is None`` rather than ``or`` so an explicit ``dim=0`` is honoured.
    if dim is None:
        dim = (1, 2, 3)
    noise_mean = noise.mean(dim=dim, keepdim=True)
    target_mean = target.mean(dim=dim, keepdim=True)
    new_noise_mean = mix * target_mean + (1 - mix) * noise_mean
    return noise - noise_mean + new_noise_mean


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
    """Add a random per-(batch, channel) constant offset to ``noise``.

    Args:
        noise: Noise tensor with at most 4 dimensions.
        noise_offset: Offset strength; ``None`` or a magnitude below 1e-6
            leaves ``noise`` untouched.

    Returns:
        The offset noise (or the input when no offset is applied).

    Raises:
        ValueError: If ``noise`` has more than 4 dimensions.
    """
    if noise_offset is None or -0.000001 < noise_offset < 0.000001:
        return noise
    if len(noise.shape) > 4:
        raise ValueError("Applying noise offset not supported for video models at this time.")
    # one random value per (batch, channel), broadcast over the spatial dims
    offset = torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
    return noise + noise_offset * offset


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import PromptEmbeds


def concat_prompt_embeddings(
        unconditional: 'PromptEmbeds',
        conditional: 'PromptEmbeds',
        n_imgs: int = 0,
):
    """Concatenate unconditional and conditional embeds and repeat per image.

    Args:
        unconditional: Embeds for the unconditional prompt.
        conditional: Embeds for the conditional prompt.
        n_imgs: Repeat count passed to ``repeat_interleave`` on dim 0.
            NOTE(review): the default of 0 produces empty tensors, so
            callers are presumably expected to always pass it -- confirm.

    Returns:
        A new ``PromptEmbeds`` built from ``[text_embeds, pooled_embeds]``;
        ``pooled_embeds`` is ``None`` unless both inputs provide one.
    """
    # imported here (not at module top) as in the original, where the
    # top-level import is only done under TYPE_CHECKING
    from toolkit.stable_diffusion_model import PromptEmbeds

    text_embeds = torch.cat(
        [unconditional.text_embeds, conditional.text_embeds]
    ).repeat_interleave(n_imgs, dim=0)

    pooled_embeds = None
    if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
        pooled_embeds = torch.cat(
            [unconditional.pooled_embeds, conditional.pooled_embeds]
        ).repeat_interleave(n_imgs, dim=0)

    return PromptEmbeds([text_embeds, pooled_embeds])


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files

    Reads an 8-byte little-endian length from the start of ``b``, skips
    that many bytes after it, and returns the SHA-256 hex digest of
    everything that follows.

    Args:
        b: Seekable binary file-like object.
    """
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    # hash in 1 MiB chunks so the whole file is never held in memory
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files

    Hashes 0x10000 bytes starting at offset 0x100000 and returns the first
    8 hex characters of the SHA-256 digest.

    Args:
        b: Seekable binary file-like object.
    """
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


if TYPE_CHECKING:
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection


def text_tokenize(
        tokenizer: 'CLIPTokenizer',
        prompts: list[str],
        truncate: bool = True,
        max_length: int = None,
        max_length_multiplier: int = 4,
):
    """Tokenize prompts, optionally allowing more than one model-length chunk.

    Args:
        tokenizer: Tokenizer exposing ``model_max_length``.
        prompts: Prompts to tokenize.
        truncate: When True, output is limited to ``max_length`` (default
            ``tokenizer.model_max_length``) and returned as-is.
        max_length: Explicit padded length. When ``None`` it is
            ``model_max_length`` (truncate) or
            ``model_max_length * max_length_multiplier`` (no truncate).
        max_length_multiplier: Multiplier used for the long-prompt default.

    Returns:
        ``input_ids`` tensor. In long-prompt mode, ``model_max_length``-wide
        chunks that consist of one repeated token for the whole batch are
        dropped.
    """
    if max_length is None:
        if truncate:
            max_length = tokenizer.model_max_length
        else:
            # allow up to 4x the max length for long prompts
            max_length = tokenizer.model_max_length * max_length_multiplier

    input_ids = tokenizer(
        prompts,
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    if truncate or max_length == tokenizer.model_max_length:
        return input_ids

    # remove additional padding
    # at least one chunk: torch.chunk rejects chunks=0, which would happen
    # when max_length is smaller than model_max_length
    num_chunks = max(1, input_ids.shape[1] // tokenizer.model_max_length)
    chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)

    # keep only chunks that are not one single repeated token
    non_redundant_chunks = [
        chunk for chunk in chunks if not chunk.eq(chunk[0, 0]).all()
    ]
    if not non_redundant_chunks:
        # every chunk was uniform; keep the first so torch.cat has input
        non_redundant_chunks = list(chunks[:1])

    return torch.cat(non_redundant_chunks, dim=1)


# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
        text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
        tokens: torch.FloatTensor,
        num_images_per_prompt: int = 1,
        max_length: int = 77,  # not sure what default to put here, always pass one?
        truncate: bool = True,
):
    """Encode token ids with one SDXL text encoder.

    Args:
        text_encoder: Encoder called with ``output_hidden_states=True``.
        tokens: Token id tensor; moved to the encoder's device.
        num_images_per_prompt: Number of times the embeds are repeated.
        max_length: Window width used to split ``tokens`` when
            ``truncate`` is False.
        truncate: When True, ``tokens`` are encoded in a single call.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: the penultimate hidden
        state (repeated ``num_images_per_prompt`` times) and element 0 of
        the encoder output. In long-prompt mode the pooled value comes
        from the first window only and windows are concatenated on dim 1.
    """
    tokens = tokens.to(text_encoder.device)
    if truncate:
        # normal short prompt 77 tokens max
        encoded = text_encoder(tokens, output_hidden_states=True)
        pooled_prompt_embeds = encoded[0]
        prompt_embeds = encoded.hidden_states[-2]  # always penultimate layer
    else:
        # handle long prompts
        prompt_embeds_list = []
        pooled_prompt_embeds = None
        for i in range(0, tokens.shape[-1], max_length):
            # todo run it through the in a single batch
            section_tokens = tokens[:, i: i + max_length]
            encoded = text_encoder(section_tokens, output_hidden_states=True)
            if pooled_prompt_embeds is None:
                # we only want the first ( I think??)
                pooled_prompt_embeds = encoded[0]
            prompt_embeds_list.append(encoded.hidden_states[-2])  # always penultimate layer

        prompt_embeds = torch.cat(prompt_embeds_list, dim=1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompts_xl(
        tokenizers: list['CLIPTokenizer'],
        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
        prompts: list[str],
        prompts2: Union[list[str], None],
        num_images_per_prompt: int = 1,
        use_text_encoder_1: bool = True,  # sdxl
        use_text_encoder_2: bool = True,  # sdxl
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode prompts with every (tokenizer, text_encoder) pair and concatenate.

    Args:
        tokenizers: Tokenizers, paired positionally with ``text_encoders``.
        text_encoders: Text encoders.
        prompts: Prompts for the first encoder.
        prompts2: Prompts for the remaining encoders; ``None`` reuses
            ``prompts``.
        num_images_per_prompt: Repeat count for the embeds.
        use_text_encoder_1: When False, encoder 0 receives empty prompts.
        use_text_encoder_2: When False, encoder 1 receives empty prompts.
        truncate: Forwarded to ``text_tokenize`` / ``text_encode_xl``.
        max_length: Token length for the first encoder; later encoders
            use the token length the first one produced.
        dropout_prob: Probability of replacing a prompt with "". Drawn
            independently for each encoder.

    Returns:
        ``(text_embeds, pooled_text_embeds)``: per-encoder embeds
        concatenated on the last dim, and the pooled output of the last
        encoder iterated.
    """
    # text_encoder and text_encoder_2's penultimate layer's output
    text_embeds_list = []
    pooled_text_embeds = None  # always text_encoder_2's pool
    if prompts2 is None:
        prompts2 = prompts

    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
        # todo, we are using a blank string to ignore that encoder for now.
        # find a better way to do this (zeroing?, removing it from the unet?)
        prompt_list_to_use = prompts if idx == 0 else prompts2
        if idx == 0 and not use_text_encoder_1:
            prompt_list_to_use = ["" for _ in prompts]
        if idx == 1 and not use_text_encoder_2:
            prompt_list_to_use = ["" for _ in prompts]

        if dropout_prob > 0.0:
            # randomly drop out prompts
            prompt_list_to_use = [
                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
            ]

        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
        # set the max length for the next one
        if idx == 0:
            max_length = text_tokens_input_ids.shape[-1]

        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
            truncate=truncate
        )

        text_embeds_list.append(text_embeds)

    bs_embed = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds


def encode_prompts_sd3(
        tokenizers: list['CLIPTokenizer'],
        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
        prompts: list[str],
        num_images_per_prompt: int = 1,
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
        pipeline=None,
):
    """Encode prompts for SD3 by delegating to the pipeline's private helpers.

    Args:
        tokenizers: Unused here; kept for a signature parallel to the
            other ``encode_prompts_*`` functions.
        text_encoders: Only ``text_encoders[0].device`` is read.
        prompts: Prompt or list of prompts, used for all three encoders.
        num_images_per_prompt: Forwarded to the pipeline helpers.
        truncate: Unused.
        max_length: Unused.
        dropout_prob: Unused (no prompt dropout is applied here).
        pipeline: Object providing ``_get_clip_prompt_embeds`` and
            ``_get_t5_prompt_embeds``. Required despite the ``None``
            default.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: CLIP embeds (padded on
        the last dim to the T5 width) concatenated with T5 embeds on
        dim -2, and the two pooled CLIP outputs concatenated on dim -1.

    Raises:
        ValueError: If ``pipeline`` is ``None``.
    """
    if pipeline is None:
        # fail with a clear message instead of an AttributeError on None
        raise ValueError("pipeline must be provided to encode_prompts_sd3")

    prompt_2 = [prompts] if isinstance(prompts, str) else prompts
    prompt_3 = [prompts] if isinstance(prompts, str) else prompts

    device = text_encoders[0].device

    prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
        prompt=prompts,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=0,
    )
    prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
        prompt=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=1,
    )
    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

    t5_prompt_embed = pipeline._get_t5_prompt_embeds(
        prompt=prompt_3,
        num_images_per_prompt=num_images_per_prompt,
        device=device
    )

    # right-pad the CLIP feature dim so it matches the T5 feature dim
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )

    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

    return prompt_embeds, pooled_prompt_embeds


# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
    """Encode token ids, optionally in ``max_length``-wide windows.

    Args:
        text_encoder: Encoder; element 0 of its output is used.
        tokens: Token id tensor; moved to the encoder's device.
        truncate: When True, encode ``tokens`` in a single call.
        max_length: Window width for long prompts. Required when
            ``truncate`` is False.

    Returns:
        The embeddings; in long-prompt mode the per-window embeddings are
        concatenated on dim 1.

    Raises:
        ValueError: If ``truncate`` is False and ``max_length`` is None.
    """
    if max_length is None and not truncate:
        raise ValueError("max_length must be set if truncate is False")
    try:
        tokens = tokens.to(text_encoder.device)
    except Exception as e:
        # print both devices to help diagnose the failure, then re-raise
        print(e)
        print("tokens.device", tokens.device)
        print("text_encoder.device", text_encoder.device)
        raise

    if truncate:
        return text_encoder(tokens)[0]

    # handle long prompts
    prompt_embeds_list = [
        text_encoder(tokens[:, i: i + max_length])[0]
        for i in range(0, tokens.shape[-1], max_length)
    ]
    return torch.cat(prompt_embeds_list, dim=1)


def encode_prompts(
        tokenizer: 'CLIPTokenizer',
        text_encoder: 'CLIPTextModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
):
    """Tokenize and encode prompts with a single tokenizer/encoder pair.

    Args:
        tokenizer: Tokenizer exposing ``model_max_length``.
        text_encoder: Text encoder.
        prompts: Prompts to encode.
        truncate: Forwarded to ``text_tokenize`` and ``text_encode``.
        max_length: Token length; defaults to ``tokenizer.model_max_length``.
        dropout_prob: Probability of replacing a prompt with "".

    Returns:
        The text embeddings returned by ``text_encode``.
    """
    if max_length is None:
        max_length = tokenizer.model_max_length

    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]

    text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
    return text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)


def encode_prompts_pixart(
        tokenizer: 'T5Tokenizer',
        text_encoder: 'T5EncoderModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
):
    """Encode prompts for PixArt with a T5 tokenizer/encoder.

    Prompts are always padded and truncated to ``max_length``; the
    ``truncate`` argument is not consulted.

    Args:
        tokenizer: T5 tokenizer.
        text_encoder: T5 encoder.
        prompts: Prompts to encode.
        truncate: Unused.
        max_length: Token length; defaults to 120.
        dropout_prob: Probability of replacing a prompt with "".

    Returns:
        ``(last_hidden_state, attention_mask)``, the mask being on the
        encoder's device.
    """
    if max_length is None:
        # See Section 3.1. of the paper.
        max_length = 120

    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]

    text_inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    prompt_attention_mask = text_inputs.attention_mask.to(text_encoder.device)
    text_input_ids = text_inputs.input_ids.to(text_encoder.device)

    prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask)

    return prompt_embeds.last_hidden_state, prompt_attention_mask


def encode_prompts_auraflow(
        tokenizer: 'T5Tokenizer',
        text_encoder: 'UMT5EncoderModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
):
    """Encode prompts for AuraFlow and zero the embeddings at masked positions.

    Prompts are always padded and truncated to ``max_length``; the
    ``truncate`` argument is not consulted.

    Args:
        tokenizer: Tokenizer.
        text_encoder: Encoder, called with the tokenizer output as kwargs.
        prompts: Prompts to encode.
        truncate: Unused.
        max_length: Token length; defaults to 256.
        dropout_prob: Probability of replacing a prompt with "".

    Returns:
        ``(prompt_embeds, prompt_attention_mask)`` where the mask has been
        expanded to the shape of ``prompt_embeds`` and already multiplied
        into it.
    """
    if max_length is None:
        max_length = 256

    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]

    device = text_encoder.device

    text_inputs = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )

    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    prompt_embeds = text_encoder(**text_inputs)[0]
    prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * prompt_attention_mask

    return prompt_embeds, prompt_attention_mask


def encode_prompts_flux(
        tokenizer: List[Union['CLIPTokenizer', 'T5Tokenizer']],
        text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
        attn_mask: bool = False,
):
    """Encode prompts for Flux: pooled CLIP output plus T5 sequence embeds.

    Args:
        tokenizer: Two tokenizers, ``[clip, t5]``.
        text_encoder: Two encoders, ``[clip, t5]``. Inputs and outputs use
            ``text_encoder[0].device``.
        prompts: Prompts to encode.
        truncate: Unused; both tokenizers always truncate.
        max_length: T5 token length; defaults to 512. CLIP uses its own
            ``model_max_length``.
        dropout_prob: Probability of replacing a prompt with "".
        attn_mask: When True, T5 embeds are multiplied by the attention
            mask.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: T5 embeds and the CLIP
        ``pooler_output``.
    """
    if max_length is None:
        max_length = 512

    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]

    device = text_encoder[0].device
    dtype = text_encoder[0].dtype

    # clip
    text_inputs = tokenizer[0](
        prompts,
        padding="max_length",
        max_length=tokenizer[0].model_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    pooled_prompt_embeds = prompt_embeds.pooler_output
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)

    # T5
    text_inputs = tokenizer[1](
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]

    # T5 embeds take the T5 encoder's dtype, not the CLIP one
    dtype = text_encoder[1].dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    if attn_mask:
        prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
        prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)

    return prompt_embeds, pooled_prompt_embeds


# for XL
def get_add_time_ids(
        height: int,
        width: int,
        dynamic_crops: bool = False,
        dtype: torch.dtype = torch.float32,
):
    """Build the 6-value ``add_time_ids`` tensor for XL conditioning.

    Args:
        height: Target height.
        width: Target width.
        dynamic_crops: When True, pretend the image was cropped from a
            randomly larger original (scale in [1, 3)) at a random offset.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (1, 6): original size, crop top-left, target size.

    Raises:
        ValueError: If the implied embedding width does not equal
            ``UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM``.
    """
    target_size = (height, width)
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # random position. torch.randint's upper bound is exclusive, so
        # add 1: this keeps the call valid when the scaled size equals the
        # target size (scale ~ 1) and makes the largest offset reachable.
        crops_coords_top_left = (
            torch.randint(0, original_size[0] - height + 1, (1,)).item(),
            torch.randint(0, original_size[1] - width + 1, (1,)).item(),
        )
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)

    # this is expected as 6
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # this is expected as 2816
    passed_add_embed_dim = (
            UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
            + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    return torch.tensor([add_time_ids], dtype=dtype)


def concat_embeddings(
        unconditional: torch.FloatTensor,
        conditional: torch.FloatTensor,
        n_imgs: int,
):
    """Concatenate unconditional and conditional embeddings, repeating each row ``n_imgs`` times."""
    return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)


def _compute_all_snr(noise_scheduler):
    """Compute the per-timestep SNR from ``noise_scheduler.alphas_cumprod``."""
    with torch.no_grad():
        alphas_cumprod = noise_scheduler.alphas_cumprod
        alpha = torch.sqrt(alphas_cumprod)
        sigma = torch.sqrt(1.0 - alphas_cumprod)
        all_snr = (alpha / sigma) ** 2
        all_snr.requires_grad = False
    return all_snr


def add_all_snr_to_noise_scheduler(noise_scheduler, device):
    """Cache the per-timestep SNR on the scheduler as ``all_snr`` (best effort).

    Does nothing when the scheduler already has ``all_snr``. Any failure
    (for example a scheduler without ``alphas_cumprod``) is deliberately
    ignored; ``get_all_snr`` computes the value on demand instead.

    Args:
        noise_scheduler: Scheduler to annotate.
        device: Device the cached tensor is moved to.
    """
    try:
        if hasattr(noise_scheduler, "all_snr"):
            return
        noise_scheduler.all_snr = _compute_all_snr(noise_scheduler).to(device)
    except Exception:
        # just move on
        pass


def get_all_snr(noise_scheduler, device):
    """Return the per-timestep SNR tensor on ``device``.

    Uses ``noise_scheduler.all_snr`` when present, otherwise computes it
    from ``noise_scheduler.alphas_cumprod`` (without caching it).
    """
    if hasattr(noise_scheduler, "all_snr"):
        return noise_scheduler.all_snr.to(device)
    return _compute_all_snr(noise_scheduler).to(device)


class LearnableSNRGamma:
    """
    This is a trainer for learnable snr gamma
    It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
    """

    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
        """Create the learnable parameters and their AdamW optimizer.

        Args:
            noise_scheduler: Scheduler the SNR values are read from.
            device: Device the parameters are created on.
        """
        self.device = device
        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
        # learnable terms of: weight = |gamma / ((snr + offset_1) * scale + offset_2)|
        self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
        self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
        self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
        self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
        self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
        # rolling window of recent per-sample mean losses
        self.buffer = []
        self.max_buffer_size = 20

    def forward(self, loss, timesteps):
        """Run one optimizer step on the SNR parameters.

        The parameters are trained so that the SNR-weighted loss moves
        towards the mean of the recent-loss buffer (squared error).

        Args:
            loss: Loss tensor with one entry per batch item on dim 0;
                detached before use.
            timesteps: Iterable of indices into the scheduler's SNR table.

        Returns:
            ``(base_snrs, gamma, offset_1, offset_2, scale)``, all
            detached, with the parameter values taken after the step.
        """
        # do a our train loop for lsnr here and return our values detached
        loss = loss.detach()
        with torch.no_grad():
            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
            for loss_chunk in loss_chunks:
                self.buffer.append(loss_chunk.mean().detach())
                if len(self.buffer) > self.max_buffer_size:
                    self.buffer.pop(0)
        all_snr = get_all_snr(self.noise_scheduler, loss.device)
        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
        base_snrs = snr.clone().detach()
        snr.requires_grad = True
        snr = (snr + self.offset_1) * self.scale + self.offset_2

        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
        snr_adjusted_loss = loss * snr_weight
        with torch.no_grad():
            target = torch.mean(torch.stack(self.buffer)).detach()

        squared_differences = (snr_adjusted_loss - target) ** 2
        local_loss = torch.mean(squared_differences)
        local_loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()

        return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()


def apply_learnable_snr_gos(
        loss,
        timesteps,
        learnable_snr_trainer: LearnableSNRGamma
):
    """Weight ``loss`` with the learned gamma-over-SNR factor.

    Calls ``learnable_snr_trainer.forward`` (which performs a training
    step on the trainer's parameters) and applies the resulting weights.

    Args:
        loss: Loss tensor.
        timesteps: Timesteps the loss entries belong to.
        learnable_snr_trainer: Trainer holding the learnable parameters.

    Returns:
        ``loss`` multiplied by ``|gamma / ((snr + offset_1) * scale + offset_2)|``.
    """
    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)

    snr = (snr + offset_1) * scale + offset_2

    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
    return loss * snr_weight


def apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler: Union['DDPMScheduler'],
        gamma,
        fixed=False,
):
    """Weight ``loss`` by gamma over SNR.

    Args:
        loss: Loss tensor.
        timesteps: Iterable of timestep tensors used to index the SNR table.
        noise_scheduler: Scheduler providing the SNR table and ``timesteps``.
        gamma: SNR gamma.
        fixed: When True use ``gamma / snr`` directly; otherwise cap the
            weight at 1, i.e. ``min(gamma / snr, 1)``.

    Returns:
        The weighted loss.
    """
    # will get it from noise scheduler if exist or will calculate it if not
    all_snr = get_all_snr(noise_scheduler, loss.device)
    # Timesteps index all_snr directly: looking each one up in
    # noise_scheduler.timesteps breaks on some schedulers.
    # When the scheduler's first timestep is 1000, shift indices down by one.
    offset = 0
    if noise_scheduler.timesteps[0] == 1000:
        offset = 1
    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    if fixed:
        snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
    else:
        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
    return loss * snr_weight


def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
    """Precondition flow-match model outputs per batch item.

    For every batch item computes ``model_output * (-sigma) + model_input``
    with sigma taken from ``noise_scheduler.get_sigmas``.

    Args:
        model_output: Batched model output.
        model_input: Batched model input, same batch size.
        timestep_tensor: One timestep per batch item.
            NOTE(review): an earlier comment mentioned unsqueezing 0-dim
            timesteps, but nothing here does so; ``shape[0]`` is read, so a
            0-dim tensor would fail -- confirm callers pass a 1-D tensor.
        noise_scheduler: Scheduler providing ``get_sigmas``.

    Returns:
        The preconditioned outputs concatenated on dim 0.
    """
    mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
    mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
    timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
    out_chunks = []
    for idx in range(model_output.shape[0]):
        sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
                                            dtype=model_output.dtype, device=model_output.device)
        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
        # Preconditioning of the model outputs.
        out_chunks.append(mo_chunks[idx] * (-sigmas) + mi_chunks[idx])
    return torch.cat(out_chunks, dim=0)