import os import sys sys.path.append(os.getcwd()) import yaml import copy import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple, Optional import numpy as np import lpips from torchvision import transforms from PIL import Image from peft import LoraConfig, get_peft_model from copy import deepcopy from tqdm import tqdm from diffusers import StableDiffusion3Pipeline, FluxPipeline from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d def inject_lora_vae(vae, lora_rank=4, init_lora_weights="gaussian", verbose=False): """ Inject LoRA into the VAE's encoder """ vae.requires_grad_(False) vae.train() # Identify modules to LoRA-ify in the encoder l_grep = ["conv1", "conv2", "conv_in", "conv_shortcut", "conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0"] l_target_modules_encoder = [] for n, p in vae.named_parameters(): if "bias" in n or "norm" in n: continue for pattern in l_grep: if (pattern in n) and ("encoder" in n): l_target_modules_encoder.append(n.replace(".weight", "")) elif ("quant_conv" in n) and ("post_quant_conv" not in n): l_target_modules_encoder.append(n.replace(".weight", "")) if verbose: print("The following VAE parameters will get LoRA:") print(l_target_modules_encoder) # Create and add a LoRA adapter lora_conf_encoder = LoraConfig( r=lora_rank, init_lora_weights=init_lora_weights, target_modules=l_target_modules_encoder ) adapter_name = "default_encoder" try: vae.add_adapter(lora_conf_encoder, adapter_name=adapter_name) vae.set_adapter(adapter_name) except ValueError as e: if "already exists" in str(e): print(f"Adapter with name {adapter_name} already exists. Skipping injection.") else: raise e return vae, l_target_modules_encoder def _find_modules(model, ancestor_class=None, search_class=[nn.Linear], exclude_children_of=[LoraInjectedLinear]): # Get the targets we should replace all linears under if ancestor_class is not None: ancestors = ( module for module in model.modules() if module.__class__.__name__ in ancestor_class ) else: # this, in case you want to naively iterate over all modules. ancestors = [module for module in model.modules()] for ancestor in ancestors: for fullname, module in ancestor.named_modules(): if any([isinstance(module, _class) for _class in search_class]): *path, name = fullname.split(".") parent = ancestor while path: parent = parent.get_submodule(path.pop(0)) if exclude_children_of and any( [isinstance(parent, _class) for _class in exclude_children_of] ): continue yield parent, name, module def inject_lora(model, ancestor_class, loras=None, r:int=4, dropout_p:float=0.0, scale:float=1.0, verbose:bool=False): model.requires_grad_(False) model.train() names = [] require_grad_params = [] # to be updated total_lora_params = 0 if loras is not None: loras = torch.load(loras, map_location=model.device, weights_only=True) loras = [lora.float() for lora in loras] for _module, name, _child_module in _find_modules(model, ancestor_class): # SiLU + Linear Block weight = _child_module.weight bias = _child_module.bias if verbose: print(f'LoRA Injection : injecting lora into {name}') _tmp = LoraInjectedLinear( _child_module.in_features, _child_module.out_features, _child_module.bias is not None, r=r, dropout_p=dropout_p, scale=scale, ) _tmp.linear.weight = nn.Parameter(weight.float()) if bias is not None: _tmp.linear.bias = nn.Parameter(bias.float()) # switch the module _tmp.to(device=_child_module.weight.device, dtype=torch.float) # keep as float / mixed precision _module._modules[name] = _tmp require_grad_params.append(_module._modules[name].lora_up.parameters()) require_grad_params.append(_module._modules[name].lora_down.parameters()) if loras != None: _module._modules[name].lora_up.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_down.weight = nn.Parameter(loras.pop(0)) _module._modules[name].lora_up.weight.requires_grad = True _module._modules[name].lora_down.weight.requires_grad = True names.append(name) if verbose: # -------- Count LoRA parameters just added -------- lora_up_count = sum(p.numel() for p in _tmp.lora_up.parameters()) lora_down_count = sum(p.numel() for p in _tmp.lora_down.parameters()) lora_total_for_this_layer = lora_up_count + lora_down_count total_lora_params += lora_total_for_this_layer print(f" Added {lora_total_for_this_layer} params " f"(lora_up={lora_up_count}, lora_down={lora_down_count})") if verbose: print(f"Total new LoRA parameters added: {total_lora_params}") return require_grad_params, names def add_mp_hook(transformer): ''' For mixed precision of LoRA. (i.e. keep LoRA as float and others as half) ''' def pre_hook(module, input): return input.float() def post_hook(module, input, output): return output.half() hooks = [] for _module, name, _child_module in _find_modules(transformer): if isinstance(_child_module, LoraInjectedLinear): hook = _child_module.lora_up.register_forward_pre_hook(pre_hook) hooks.append(hook) hook = _child_module.lora_down.register_forward_hook(post_hook) hooks.append(hook) return transformer, hooks def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0, mode_scale: Optional[float] = None ): """ Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper reference: https://arxiv.org/abs/2403.03206v1. """ if weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") u = torch.nn.functional.sigmoid(u) elif weighting_scheme == "mode": u = torch.rand(size=(batch_size,), device="cpu") u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) else: u = torch.rand(size=(batch_size,), device="cpu") return u def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas): """ Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper reference: https://arxiv.org/abs/2403.03206v1. """ if weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() elif weighting_scheme == "cosmap": bot = 1 - 2 * sigmas + 2 * sigmas**2 weighting = 2 / (math.pi * bot) else: weighting = torch.ones_like(sigmas) return weighting class StableDiffusion3Base(): def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16): self.device = device self.dtype = dtype pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype) self.scheduler = pipe.scheduler self.tokenizer_1 = pipe.tokenizer self.tokenizer_2 = pipe.tokenizer_2 self.tokenizer_3 = pipe.tokenizer_3 self.text_enc_1 = pipe.text_encoder.to(device) self.text_enc_2 = pipe.text_encoder_2.to(device) self.text_enc_3 = pipe.text_encoder_3.to(device) self.vae=pipe.vae.to(device) self.transformer = pipe.transformer.to(device) self.transformer.eval() self.transformer.requires_grad_(False) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8 ) del pipe def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]: ''' We assume that 1. number of tokens < max_length 2. one prompt for one image ''' # CLIP encode (used for modulation of adaLN-zero) # now, we have two CLIPs text_clip1_ids = self.tokenizer_1(prompt, padding="max_length", max_length=77, truncation=True, return_tensors='pt').input_ids text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.device), output_hidden_states=True) pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.device) text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device) text_clip2_ids = self.tokenizer_2(prompt, padding="max_length", max_length=77, truncation=True, return_tensors='pt').input_ids text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.device), output_hidden_states=True) pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.device) text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device) # T5 encode (used for text condition) text_t5_ids = self.tokenizer_3(prompt, padding="max_length", max_length=512, truncation=True, add_special_tokens=True, return_tensors='pt').input_ids text_t5_emb = self.text_enc_3(text_t5_ids.to(self.device))[0] text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.device) # Merge clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1) clip_prompt_emb = torch.nn.functional.pad( clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1]) ) prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2) pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1) return prompt_emb, pooled_prompt_emb def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs): H, W = img_size lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor lC = self.transformer.config.in_channels latent_shape = (batch_size, lC, lH, lW) z = torch.randn(latent_shape, device=self.device, dtype=self.dtype) return z def encode(self, image: torch.Tensor) -> torch.Tensor: z = self.vae.encode(image).latent_dist.sample() z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor return z def decode(self, z: torch.Tensor) -> torch.Tensor: z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor return self.vae.decode(z, return_dict=False)[0] class SD3Euler(StableDiffusion3Base): def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'): super().__init__(model_key=model_key, device=device) def inversion(self, src_img, prompts: List[str], NFE:int, cfg_scale: float=1.0, batch_size: int=1): # encode text prompts prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) # initialize latent src_img = src_img.to(device=self.device, dtype=self.dtype) with torch.no_grad(): z = self.encode(src_img) z0 = z.clone() # timesteps (default option. You can make your custom here.) self.scheduler.set_timesteps(NFE, device=self.device) timesteps = self.scheduler.timesteps timesteps = torch.cat([timesteps, torch.zeros(1, device=self.device)]) timesteps = reversed(timesteps) sigmas = timesteps / self.scheduler.config.num_train_timesteps # Solve ODE pbar = tqdm(timesteps[:-1], total=NFE, desc='SD3 Euler Inversion') for i, t in enumerate(pbar): timestep = t.expand(z.shape[0]).to(self.device) pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) if cfg_scale != 1.0: pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) else: pred_null_v = 0.0 sigma = sigmas[i] sigma_next = sigmas[i+1] z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v)) return z def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None, cfg_scale: float=1.0, batch_size: int = 1, latent:Optional[torch.Tensor]=None): imgH, imgW = img_shape if img_shape is not None else (512, 512) # encode text prompts with torch.no_grad(): prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) # initialize latent if latent is None: z = self.initialize_latent((imgH, imgW), batch_size) else: z = latent # timesteps (default option. You can make your custom here.) self.scheduler.set_timesteps(NFE, device=self.device) timesteps = self.scheduler.timesteps sigmas = timesteps / self.scheduler.config.num_train_timesteps # Solve ODE pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler') for i, t in enumerate(pbar): timestep = t.expand(z.shape[0]).to(self.device) pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) if cfg_scale != 1.0: pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) else: pred_null_v = 0.0 sigma = sigmas[i] sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v)) # decode with torch.no_grad(): img = self.decode(z) return img class OSEDiff_SD3_GEN(torch.nn.Module): def __init__(self, args, base_model): super().__init__() self.args = args self.model = base_model # Add lora to transformer print('Adding Lora to OSEDiff_SD3_GEN') self.transformer_gen = copy.deepcopy(self.model.transformer) self.transformer_gen.to('cuda') # self.transformer_gen = self.transformer_gen.float() self.transformer_gen.requires_grad_(False) self.transformer_gen.train() self.transformer_gen, hooks = add_mp_hook(self.transformer_gen) self.hooks = hooks lora_params, _ = inject_lora(self.transformer_gen, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True) # self.lora_params = lora_params for name, param in self.transformer_gen.named_parameters(): if "lora_" in name: param.requires_grad = True # LoRA up/down else: param.requires_grad = False # everything else # Insert LoRA into VAE print("Adding Lora to VAE") self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=True) def predict_vector(self, z, t, prompt_emb, pooled_emb): v = self.transformer_gen(hidden_states=z, timestep=t, pooled_projections=pooled_emb, encoder_hidden_states=prompt_emb, return_dict=False)[0] return v def forward(self, x_src, batch=None, args=None): z_src = self.model.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device)) z_src = z_src.to(self.transformer_gen.device) # calculate prompt_embeddings and neg_prompt_embeddings batch_size, _, _, _ = x_src.shape with torch.no_grad(): prompt_embeds, pooled_embeds = self.model.encode_prompt(batch["prompt"], batch_size) neg_prompt_embeds, neg_pooled_embeds = self.model.encode_prompt(batch["neg_prompt"], batch_size) NFE = 1 self.model.scheduler.set_timesteps(NFE, device=self.model.device) timesteps = self.model.scheduler.timesteps sigmas = timesteps / self.model.scheduler.config.num_train_timesteps sigmas = sigmas.to(self.transformer_gen.device) # Solve ODE i = 0 t = timesteps[0] timestep = t.expand(z_src.shape[0]).to(self.transformer_gen.device) prompt_embeds = prompt_embeds.to(self.transformer_gen.device, dtype=torch.float32) pooled_embeds = pooled_embeds.to(self.transformer_gen.device, dtype=torch.float32) pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds) pred_null_v = 0.0 sigma = sigmas[i] sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 z_src = z_src + (sigma_next - sigma) * (pred_null_v + 1 * (pred_v - pred_null_v)) output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device)) return output_image, z_src, prompt_embeds, pooled_embeds class OSEDiff_SD3_REG(torch.nn.Module): def __init__(self, args, base_model): super().__init__() self.args = args self.model = base_model self.transformer_org = self.model.transformer # Add lora to transformer print('Adding Lora to OSEDiff_SD3_REG') self.transformer_reg = copy.deepcopy(self.transformer_org) self.transformer_reg.to('cuda') self.transformer_reg.requires_grad_(False) self.transformer_reg.train() self.transformer_reg, hooks = add_mp_hook(self.transformer_reg) self.hooks = hooks lora_params, _ = inject_lora(self.transformer_reg, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True) for name, param in self.transformer_reg.named_parameters(): if "lora_" in name: param.requires_grad = True # LoRA up/down else: param.requires_grad = False # everything else def predict_vector_reg(self, z, t, prompt_emb, pooled_emb): v = self.transformer_reg(hidden_states=z, timestep=t, pooled_projections=pooled_emb, encoder_hidden_states=prompt_emb, return_dict=False)[0] return v def predict_vector_org(self, z, t, prompt_emb, pooled_emb): v = self.transformer_org(hidden_states=z, timestep=t, pooled_projections=pooled_emb, encoder_hidden_states=prompt_emb, return_dict=False)[0] return v def distribution_matching_loss(self, z0, prompt_embeds, pooled_embeds, global_step, args): with torch.no_grad(): device = self.transformer_reg.device # get timesteps and sigma u = compute_density_for_timestep_sampling( weighting_scheme="uniform", batch_size=1, logit_mean=0.0, logit_std=1.0, mode_scale=1.29, ) t_idx = (u*1000).long().to(device) self.model.scheduler.set_timesteps(1000, device=device) times = self.model.scheduler.timesteps t = times[t_idx] sigma = t / 1000 # get noise and xt z0 = z0.to(device) noise = torch.randn_like(z0) sigma = sigma.half() zt = (1-sigma) * z0 + sigma * noise # Get x0_prediction of transformer_reg v_pred_reg = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device)) reg_model_pred = v_pred_reg * (-sigma) + zt # this is x0_prediction for reg # Get x0_prediction of transformer_org org_device = self.transformer_org.device v_pred_org = self.predict_vector_org(zt.to(org_device), t.to(org_device), prompt_embeds.to(org_device), pooled_embeds.to(org_device)) org_model_pred = v_pred_org * (-sigma.to(org_device)) + zt.to(org_device) # this is x0_prediction for org # Visualization if global_step % 100 == 1: self.vsd_visualization(z0, noise, zt, reg_model_pred, org_model_pred, global_step, args) weighting_factor = torch.abs(z0 - org_model_pred.to(device)).mean(dim=[1, 2, 3], keepdim=True) grad = (reg_model_pred - org_model_pred.to(device)) / weighting_factor loss = F.mse_loss(z0, (z0 - grad).detach()) return loss def vsd_visualization(self, z0, noise, zt, reg_model_pred, org_model_pred, global_step, args): #-------- Visualization --------# # 1. Visualize latents, noise, zt z0_img = self.model.decode(z0.to(dtype=torch.float32, device=self.model.vae.device)) ns_img = self.model.decode(noise.to(dtype=torch.float32, device=self.model.vae.device)) zt_img = self.model.decode(zt.to(dtype=torch.float32, device=self.model.vae.device)) z0_img_pil = transforms.ToPILImage()(torch.clamp(z0_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5) ns_img_pil = transforms.ToPILImage()(torch.clamp(ns_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5) zt_img_pil = transforms.ToPILImage()(torch.clamp(zt_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5) # 2. Visualize reg_img, org_img reg_img = self.model.decode(reg_model_pred.to(dtype=torch.float32, device=self.model.vae.device)) org_img = self.model.decode(org_model_pred.to(dtype=torch.float32, device=self.model.vae.device)) reg_img_pil = transforms.ToPILImage()(torch.clamp(reg_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5) org_img_pil = transforms.ToPILImage()(torch.clamp(org_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5) # Concatenate images side by side w, h = z0_img_pil.width, z0_img_pil.height combined_image = Image.new('RGB', (w*5, h)) combined_image.paste(z0_img_pil, (0, 0)) combined_image.paste(ns_img_pil, (w, 0)) combined_image.paste(zt_img_pil, (w*2, 0)) combined_image.paste(reg_img_pil, (w*3, 0)) combined_image.paste(org_img_pil, (w*4, 0)) combined_image.save(os.path.join(args.output_dir, f'visualization/vsd/{global_step}.png')) #-------- Visualization --------# def diff_loss(self, z0, prompt_embeds, pooled_embeds, net_lpips, args): device = self.transformer_reg.device u = compute_density_for_timestep_sampling( weighting_scheme="uniform", batch_size=1, logit_mean=0.0, logit_std=1.0, mode_scale=1.29, ) t_idx = (u*1000).long().to(device) self.model.scheduler.set_timesteps(1000, device=device) times = self.model.scheduler.timesteps t = times[t_idx] sigma = t / 1000 z0 = z0.to(device) z0, prompt_embeds = z0.detach(), prompt_embeds.detach() noise = torch.randn_like(z0) sigma = sigma.half() zt = (1-sigma) * z0 + sigma * noise # noisy latents # v-prediction v_pred = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device)) model_pred = v_pred * (-sigma) + zt target = z0 loss_weight = compute_loss_weighting_for_sd3("logit_normal", sigma) diffusion_loss = loss_weight.float() * F.mse_loss(model_pred.float(), target.float()) loss_d = diffusion_loss return loss_d.mean() class OSEDiff_SD3_TEST(torch.nn.Module): def __init__(self, args, base_model): super().__init__() self.args = args self.model = base_model self.lora_path = args.lora_path self.vae_path = args.vae_path # Add lora to transformer print(f'Loading LoRA to Transformer from {self.lora_path}') self.model.transformer.requires_grad_(False) lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False) for name, param in self.model.transformer.named_parameters(): param.requires_grad = False # Insert LoRA into VAE print(f"Loading LoRA to VAE from {self.vae_path}") self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False) encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu") self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16) def predict_vector(self, z, t, prompt_emb, pooled_emb): v = self.model.transformer(hidden_states=z, timestep=t, pooled_projections=pooled_emb, encoder_hidden_states=prompt_emb, return_dict=False)[0] return v @torch.no_grad() def forward(self, x_src, prompt): z_src = self.model.vae.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device)).latent_dist.sample() * self.model.vae.config.scaling_factor z_src = z_src.to(self.model.transformer.device) # calculate prompt_embeddings and neg_prompt_embeddings batch_size, _, _, _ = x_src.shape with torch.no_grad(): prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size) self.model.scheduler.set_timesteps(1, device=self.model.device) timesteps = self.model.scheduler.timesteps # Solve ODE t = timesteps[0] timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device) prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32) pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32) pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds) z_src = z_src - pred_v with torch.no_grad(): output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device)) return output_image class OSEDiff_SD3_TEST_efficient(torch.nn.Module): def __init__(self, args, base_model): super().__init__() self.args = args self.model = base_model self.lora_path = args.lora_path self.vae_path = args.vae_path # Add lora to transformer print(f'Loading LoRA to Transformer from {self.lora_path}') self.model.transformer.requires_grad_(False) lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False) for name, param in self.model.transformer.named_parameters(): param.requires_grad = False # Insert LoRA into VAE print(f"Loading LoRA to VAE from {self.vae_path}") self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False) encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu") self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16) def predict_vector(self, z, t, prompt_emb, pooled_emb): v = self.model.transformer(hidden_states=z, timestep=t, pooled_projections=pooled_emb, encoder_hidden_states=prompt_emb, return_dict=False)[0] return v @torch.no_grad() def forward(self, x_src, prompt): z_src = self.model.vae.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device)).latent_dist.sample() * self.model.vae.config.scaling_factor z_src = z_src.to(self.model.transformer.device) # calculate prompt_embeddings batch_size, _, _, _ = x_src.shape prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size) self.model.scheduler.set_timesteps(1, device=self.model.device) timesteps = self.model.scheduler.timesteps # Solve ODE t = timesteps[0] timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device) prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32) pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32) pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds) z_src = z_src - pred_v output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device)) return output_image