import torch

from diffusers import LCMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *


class Hack_SDPipe_Stepwise(StableDiffusionPipeline):

    @torch.no_grad()
    def _use_lcm(self, use=True, ckpt='latent-consistency/lcm-lora-sdv1-5'):
        if use:
            self.use_lcm = True
            adapter_id = ckpt
            self.scheduler = LCMScheduler.from_config(self.scheduler.config)
            # load and fuse the LCM-LoRA weights; LCM sampling runs without CFG
            self._guidance_scale = 0.0
            self.load_lora_weights(adapter_id)
            self.fuse_lora()
        else:
            self.use_lcm = False
            self._guidance_scale = 7.5

    @torch.no_grad()
    def re_init(self, num_inference_steps=50):
        # hyper-parameters
        eta = 0.0
        timesteps = None
        generator = None
        self._clip_skip = None
        self._interrupt = False
        self._guidance_rescale = 0.0
        self.added_cond_kwargs = None
        self._cross_attention_kwargs = None
        self._do_classifier_free_guidance = (
            self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
        )

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device

        # 4. Prepare timesteps
        self.timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        self.extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.2 Optionally get Guidance Scale Embedding
        self.timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * 1)
            self.timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device)

    @torch.no_grad()
    def _encode_text_prompt(self, prompt, negative_prompt='fake,ugly,unreal'):
        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            self._execution_device,
            1,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier-free guidance we need two forward passes.
        # Concatenate the unconditional and text embeddings into a single batch
        # to avoid running the UNet twice.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    @torch.no_grad()
    def _step_noise(self, latents, time_step, prompt_embeds):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_step)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            time_step,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=self.timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=self.added_cond_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        return noise_pred

    # @torch.no_grad()
    def _encode(self, input):
        '''
        Single condition encoding.
        input:  B3HW image tensor
        return: B4H'W' latent tensor
        Note: in low-VRAM mode the VAE lives on the CPU, so the input must be on the CPU as well.
        '''
        h = self.vae.encoder(input)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # use the posterior mean and scale it into the latent space
        latent = mean * self.vae.config.scaling_factor
        return latent

    def _decode(self, latent):
        '''
        Single target decoding.
        input:  B4H'W' latent tensor
        return: B3HW image tensor
        '''
        # undo the latent scaling
        latent = latent / self.vae.config.scaling_factor
        # decode
        z = self.vae.post_quant_conv(latent)
        output = self.vae.decoder(z)
        return output

    def _solve_x0_full_step(self, latents, noise_pred, t):
        # closed-form x0 estimate: x0 = (x_t - sigma_t * eps) / alpha_t
        self.alpha_t = torch.sqrt(self.scheduler.alphas_cumprod).to(t.device)
        self.sigma_t = torch.sqrt(1 - self.scheduler.alphas_cumprod).to(t.device)
        a_t, s_t = self.alpha_t[t], self.sigma_t[t]
        x0_latents = (latents - s_t * noise_pred) / a_t
        x0 = self._decode(x0_latents)
        return x0_latents, x0

    def _solve_x0(self, latents, noise_pred, t):
        x0_latents = self.scheduler.step(noise_pred, t.squeeze(), latents)
        # this is a "fake" denoise: rewind the internal step index so the real loop is unaffected
        self.scheduler._step_index -= 1
        # results
        x0_latents = x0_latents.denoised if self.use_lcm else x0_latents.pred_original_sample
        x0 = self._decode(x0_latents)
        return x0_latents, x0

    def _step_denoise(self, latents, noise_pred, t):
        # advance one scheduler step: x_t -> x_{t-1}
        latents = self.scheduler.step(noise_pred, t.squeeze(), latents).prev_sample
        return latents

    def xt_x0_noise(
        self,
        xt_latents: torch.Tensor,
        x0_latents: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alphas_cumprod and timesteps have the same device and dtype as the samples.
        # Moving self.alphas_cumprod to the device here avoids redundant CPU-to-GPU transfers
        # for the subsequent calls.
        alphas_cumprod = self.scheduler.alphas_cumprod.to(dtype=xt_latents.dtype, device=xt_latents.device)
        timesteps = timesteps.to(xt_latents.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # invert the forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
        noise = (xt_latents - sqrt_alpha_prod * x0_latents) / sqrt_one_minus_alpha_prod
        return noise

    def _solve_noise_given_x0_latent(self, latents, x0_latents, t):
        noise = self.xt_x0_noise(latents, x0_latents, t)
        # -------------------- noise for supervision -----------------
        if self.scheduler.config.prediction_type == "epsilon":
            noise = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            noise = self.scheduler.get_velocity(x0_latents, noise, t)
        # ------------------------------------------------------------
        return noise
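

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline code). It assumes a
# Stable Diffusion v1-5 checkpoint ("runwayml/stable-diffusion-v1-5"), a CUDA
# device, and 512x512 generation (64x64 latents); adjust these to your setup.
# It simply walks the stepwise API above: prepare the scheduler state, encode
# a prompt, alternate noise prediction and denoising, then decode the latents.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = Hack_SDPipe_Stepwise.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    pipe._use_lcm(use=False)              # plain SD sampling with CFG (scale 7.5)
    pipe.re_init(num_inference_steps=50)  # prepare timesteps and step kwargs

    with torch.no_grad():
        prompt_embeds = pipe._encode_text_prompt("a photo of a cat")

        # start from pure Gaussian noise in latent space (B x 4 x H/8 x W/8)
        latents = torch.randn(
            1, pipe.unet.config.in_channels, 64, 64,
            device=pipe._execution_device, dtype=prompt_embeds.dtype,
        ) * pipe.scheduler.init_noise_sigma

        for t in pipe.timesteps:
            noise_pred = pipe._step_noise(latents, t, prompt_embeds)
            latents = pipe._step_denoise(latents, noise_pred, t)

        image = pipe._decode(latents)  # B3HW tensor, roughly in [-1, 1]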