import ast
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from safetensors import safe_open


def update_args_from_yaml(group, args, parser):
    for key, value in group.items():
        if isinstance(value, dict):
            update_args_from_yaml(value, args, parser)
        else:
            if value == 'None' or value == 'null':
                value = None
            else:
                # Look up the argparse type registered for this destination (default: str).
                arg_type = next((action.type for action in parser._actions if action.dest == key), str)
                if arg_type is ast.literal_eval:
                    pass
                elif arg_type is not None and not isinstance(value, arg_type):
                    try:
                        value = arg_type(value)
                    except ValueError as e:
                        raise ValueError(f"Cannot convert {key} to {arg_type}: {e}")
            setattr(args, key, value)


def safe_load(model_path):
    assert "safetensors" in model_path
    state_dict = {}
    with safe_open(model_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
    return state_dict
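# Minimal usage sketch for safe_load (the checkpoint path and `model` object are
# hypothetical; kept as comments so importing this module has no side effects):
#   state_dict = safe_load("/path/to/model.safetensors")
#   model.load_state_dict(state_dict)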
@dataclass
class DDIMSchedulerStepOutput:
    prev_sample: torch.Tensor  # x_{t-1}
    pred_original_sample: Optional[torch.Tensor] = None  # x_0


@dataclass
class DDIMSchedulerConversionOutput:
    pred_epsilon: torch.Tensor
    pred_original_sample: torch.Tensor
    pred_velocity: torch.Tensor


class DDIMScheduler:
    prediction_types = ["epsilon", "sample", "v_prediction"]

    def __init__(
        self,
        num_train_timesteps: int,
        num_inference_timesteps: int,
        betas: torch.Tensor,
        set_alpha_to_one: bool = True,
        set_inference_timesteps_from_pure_noise: bool = True,
        inference_timesteps: Union[str, List[int]] = "trailing",
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        skip_step: bool = False,
        original_inference_step: int = 20,
        steps_offset: int = 0,
    ):
        assert num_train_timesteps > 0
        assert num_train_timesteps >= num_inference_timesteps
        assert num_train_timesteps == betas.size(0)
        assert betas.ndim == 1

        self.module_name = 'AutoAIGC'
        self.config_list = {
            "num_train_timesteps": num_train_timesteps,
            "num_inference_timesteps": num_inference_timesteps,
            "betas": betas,
            "set_alpha_to_one": set_alpha_to_one,
            "set_inference_timesteps_from_pure_noise": set_inference_timesteps_from_pure_noise,
            "inference_timesteps": inference_timesteps,
        }
        self.module_info = str(self.config_list)

        device = device or betas.device
        self.num_train_timesteps = num_train_timesteps
        self.num_inference_steps = num_inference_timesteps
        self.steps_offset = steps_offset
        self.betas = betas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.final_alpha_cumprod = (
            torch.tensor(1.0, device=device, dtype=dtype) if set_alpha_to_one else self.alphas_cumprod[0]
        )

        if isinstance(inference_timesteps, torch.Tensor):
            assert len(inference_timesteps) == num_inference_timesteps
            self.timesteps = inference_timesteps.cpu().numpy().tolist()
        elif set_inference_timesteps_from_pure_noise:
            if inference_timesteps == "trailing":
                # Example 20 steps: [999, 949, 899, ..., 149, 99, 49]
                if skip_step:
                    # Build a denser trailing schedule of `original_inference_step` steps,
                    # then subsample it down to `num_inference_timesteps` steps.
                    original_timesteps = torch.arange(
                        num_train_timesteps - 1, -1, -num_train_timesteps / original_inference_step, device=device
                    ).round().int().tolist()
                    skipping_step = len(original_timesteps) // num_inference_timesteps
                    self.timesteps = original_timesteps[::skipping_step][:num_inference_timesteps]
                else:
                    # Example 10 steps: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
                    self.timesteps = torch.arange(
                        num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=device
                    ).round().int().tolist()
            elif inference_timesteps == "linspace":
                # Fixed DDIM timesteps; guarantees the schedule starts from 999.
                # Example 20 steps: [999, 946, 894, ..., 105, 53, 0]
                # Example 10 steps: [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
                self.timesteps = torch.linspace(
                    0, num_train_timesteps - 1, num_inference_timesteps, device=device
                ).round().int().flip(0).tolist()
            elif inference_timesteps == "leading":
                # Creates integer timesteps by multiplying by the step ratio
                # (integer arithmetic avoids rounding issues); normalized to a
                # Python list for consistency with the other branches.
                step_ratio = num_train_timesteps // num_inference_timesteps
                self.timesteps = torch.arange(0, num_inference_timesteps).mul(step_ratio).flip(dims=[0]).tolist()
                # self.timesteps += self.steps_offset
                # Note: the original SD/DDIM "leading" spacing arguably has a bug,
                # since inference does not start from 999.
                # Example 20 steps: [950, 900, 850, ..., 100, 50, 0]
                # Example 10 steps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
            else:
                raise NotImplementedError
        elif inference_timesteps == "leading":
            # Original SD/DDIM "leading" spacing (see the note above); does not start from 999.
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
        else:
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))

        self.to(device=device)

    def to(self, device):
        self.betas = self.betas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
        # self.timesteps is a plain Python list and stays on the host.
        return self
    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> DDIMSchedulerStepOutput:
        if isinstance(timestep, int):
            # 1. get previous step value (t-1)
            idx = self.timesteps.index(timestep)
            prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None

            # 2. compute alphas, betas
            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
        else:
            # 1. get previous step value (t-1), batched.
            timesteps = torch.tensor(self.timesteps).to(timestep.device)
            # Find the index idx of each timestep within the schedule.
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            # idx + 1 is the next (less noisy) timestep in the schedule; for the
            # final step it runs past the end, so use -1 as a sentinel meaning
            # "no previous timestep", which selects final_alpha_cumprod below.
            prev_idx = idx.add(1)
            prev_timestep = torch.where(
                prev_idx < self.num_inference_steps,
                timesteps[prev_idx.clamp_max(self.num_inference_steps - 1)],
                torch.full_like(idx, -1),
            )

            # 2. compute alphas, betas
            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep.clamp_min(0)]
            alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev

            bs = timestep.size(0)
            alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
            alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
            beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
            beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)

        # rcfg: stash the previous-step coefficients for later use.
        self.stock_alpha_prod_t_prev = alpha_prod_t_prev
        self.stock_beta_prod_t_prev = beta_prod_t_prev

        # 3. compute predicted original sample from predicted noise, also called "predicted x_0"
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        # 4. clip or threshold "predicted x_0"
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        if dynamic_threshold is not None:
            # Dynamic thresholding from https://arxiv.org/abs/2205.11487
            dynamic_max_val = (
                pred_original_sample
                .flatten(1)
                .abs()
                .float()
                .quantile(dynamic_threshold, dim=1)
                .type_as(pred_original_sample)
                .clamp_min(1)
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            )
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # 5. compute variance sigma_t(eta), formula (16) of https://arxiv.org/pdf/2010.02502.pdf
        #    sigma_t = sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** 0.5

        # 6. compute "direction pointing to x_t", formula (12) of https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * pred_epsilon

        # 7. compute x_{t-1} without random noise, formula (12) of https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

        # 8. add random noise if needed
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return DDIMSchedulerStepOutput(
            prev_sample=prev_sample,  # x_{t-1}
            pred_original_sample=pred_original_sample,  # x_0
        )
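    # Sketch of one denoising update with an epsilon-prediction model, where
    # `unet` is a hypothetical denoiser, not part of this module:
    #   out = scheduler.step(unet(x_t, t), "epsilon", t, x_t)
    #   x_t = out.prev_sample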
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
        replace_noise=True,
    ) -> torch.Tensor:
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
        if replace_noise:
            # At t = 999 the noised sample is forced to pure noise (alpha_prod set to 0).
            indices = (timesteps == 999).nonzero()
            if indices.numel() > 0:
                alpha_prod_t[indices] = 0
        return alpha_prod_t ** 0.5 * original_samples + (1 - alpha_prod_t) ** 0.5 * noise

    def add_noise_lcm(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timestep: Union[torch.Tensor, int],
    ) -> torch.Tensor:
        if isinstance(timestep, int):
            # get previous step value (t-1)
            idx = self.timesteps.index(timestep)
            prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
        else:
            # Batched previous-timestep lookup; same scheme as in step() above.
            timesteps = torch.tensor(self.timesteps).to(timestep.device)
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            prev_idx = idx.add(1)
            prev_timestep = torch.where(
                prev_idx < self.num_inference_steps,
                timesteps[prev_idx.clamp_max(self.num_inference_steps - 1)],
                torch.full_like(idx, -1),
            )
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep.clamp_min(0)]
            alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)

        alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, *([1] * (original_samples.ndim - 1)))
        return alpha_prod_t_prev ** 0.5 * original_samples + (1 - alpha_prod_t_prev) ** 0.5 * noise

    def convert_output(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        sample: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
    ) -> DDIMSchedulerConversionOutput:
        assert model_output_type in self.prediction_types
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t

        if model_output_type == "epsilon":
            pred_epsilon = model_output
            pred_original_sample = (sample - beta_prod_t ** 0.5 * pred_epsilon) / alpha_prod_t ** 0.5
            pred_velocity = alpha_prod_t ** 0.5 * pred_epsilon - (1 - alpha_prod_t) ** 0.5 * pred_original_sample
        elif model_output_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / beta_prod_t ** 0.5
            pred_velocity = alpha_prod_t ** 0.5 * pred_epsilon - (1 - alpha_prod_t) ** 0.5 * pred_original_sample
        elif model_output_type == "v_prediction":
            pred_velocity = model_output
            pred_original_sample = alpha_prod_t ** 0.5 * sample - beta_prod_t ** 0.5 * model_output
            pred_epsilon = alpha_prod_t ** 0.5 * model_output + beta_prod_t ** 0.5 * sample
        else:
            raise ValueError("Unknown prediction type")

        return DDIMSchedulerConversionOutput(
            pred_epsilon=pred_epsilon,
            pred_original_sample=pred_original_sample,
            pred_velocity=pred_velocity,
        )
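    # The three parameterizations above are linked through the forward process
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    # and the velocity definition
    #   v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0,
    # so given x_t, any one of (eps, x_0, v) determines the other two.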
    def get_velocity(
        self,
        sample: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.FloatTensor:
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        return alpha_prod_t ** 0.5 * noise - (1 - alpha_prod_t) ** 0.5 * sample
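
# Minimal end-to-end sketch (not part of the original module): runs the full
# DDIM loop with random tensors standing in for a real denoiser. The linear
# beta schedule and the latent shape below are assumptions for the demo.
if __name__ == "__main__":
    betas = torch.linspace(1e-4, 0.02, 1000)  # assumed linear beta schedule
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        num_inference_timesteps=10,
        betas=betas,
        inference_timesteps="trailing",
    )
    x = torch.randn(1, 4, 8, 8)  # start from pure noise
    for t in scheduler.timesteps:  # [999, 899, ..., 99]
        model_output = torch.randn_like(x)  # stand-in for unet(x, t)
        x = scheduler.step(model_output, "epsilon", t, x).prev_sample
    print(x.shape)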