from ..models.hunyuan_dit import HunyuanDiT
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompters import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
import numpy as np


# Builds the 2D rotary positional embeddings (RoPE) used by HunyuanDiT on its latent patch grid.
class ImageSizeManager:

    def __init__(self):
        pass

    def _to_tuple(self, x):
        if isinstance(x, int):
            return x, x
        else:
            return x

    def get_fill_resize_and_crop(self, src, tgt):
        th, tw = self._to_tuple(tgt)
        h, w = self._to_tuple(src)

        tr = th / tw  # aspect ratio of the base resolution
        r = h / w  # aspect ratio of the target resolution

        # resize
        if r > tr:
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            resize_width = tw
            resize_height = int(round(tw / w * h))

        # The target resolution has been resized down to the base resolution; now center-crop.
        crop_top = int(round((th - resize_height) / 2.0))
        crop_left = int(round((tw - resize_width) / 2.0))

        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

    def get_meshgrid(self, start, *args):
        if len(args) == 0:
            # start is grid_size
            num = self._to_tuple(start)
            start = (0, 0)
            stop = num
        elif len(args) == 1:
            # start is start, args[0] is stop, step is 1
            start = self._to_tuple(start)
            stop = self._to_tuple(args[0])
            num = (stop[0] - start[0], stop[1] - start[1])
        elif len(args) == 2:
            # start is start, args[0] is stop, args[1] is num
            start = self._to_tuple(start)  # top-left corner, e.g. (12, 0)
            stop = self._to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
            num = self._to_tuple(args[1])  # target size, e.g. (32, 124)
        else:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        # e.g. 32 samples interpolated between 12 and 20, 124 samples between 0 and 32
        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
        grid = np.stack(grid, axis=0)  # [2, W, H]
        return grid

    def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
        grid = self.get_meshgrid(start, *args)  # [2, H, w]
        grid = grid.reshape([2, 1, *grid.shape[1:]])  # a sampling grid whose resolution matches the target resolution
        pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
        return pos_embed

    def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
        assert embed_dim % 4 == 0

        # use half of dimensions to encode grid_h
        emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
        emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

        if use_real:
            cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
            sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
            return cos, sin
        else:
            emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
            return emb

    def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
        if isinstance(pos, int):
            pos = np.arange(pos)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
        t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
        freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
        if use_real:
            freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        else:
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
            return freqs_cis

    def calc_rope(self, height, width):
        patch_size = 2
        head_size = 88
        # The VAE downsamples by 8 and the DiT patchifies by patch_size, so the RoPE grid is (height/16, width/16).
        th = height // 8 // patch_size
        tw = width // 8 // patch_size
        base_size = 512 // 8 // patch_size
        start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
        return rope


class HunyuanDiTImagePipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
        self.prompter = HunyuanDiTPrompter()
        self.image_size_manager = ImageSizeManager()
        # models
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
        self.dit: HunyuanDiT = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None

    def denoising_model(self):
        return self.dit

    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        # Main models
        self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
        self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
        self.dit = model_manager.fetch_model("hunyuan_dit")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
        self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        pipe = HunyuanDiTImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, prompt_refiner_classes)
        return pipe

    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image

    def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
        text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
            prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=positive, device=self.device
        )
        return {
            "text_emb": text_emb,
            "text_emb_mask": text_emb_mask,
            "text_emb_t5": text_emb_t5,
            "text_emb_mask_t5": text_emb_mask_t5,
        }

    def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
        batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
        if tiled:
            height, width = tile_size * 16, tile_size * 16
        image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
        freqs_cis_img = self.image_size_manager.calc_rope(height, width)
        image_meta_size = torch.stack([image_meta_size] * batch_size)
        return {
            "size_emb": image_meta_size,
            "freq_cis_img": (
                freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device),
                freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device),
            ),
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride,
        }

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        local_prompts=[],
        masks=[],
        mask_scales=[],
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=1,
        input_image=None,
        reference_strengths=[0.4],
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        noise = torch.randn((1, 4, height // 8, width // 8), device=self.device, dtype=self.torch_dtype)
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
            latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise.clone()

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
        prompt_emb_locals = [
            self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
            for prompt_local in local_prompts
        ]

        # Prepare positional embeddings (RoPE) and size embedding
        extra_input = self.prepare_extra_input(latents, tiled, tile_size, tile_stride)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)

            # Positive side
            inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
            noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)

            if cfg_scale != 1.0:
                # Negative side
                noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **extra_input)
                # Classifier-free guidance
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Scheduler update
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
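

# Usage sketch (illustrative only, not part of the pipeline): a minimal example of driving this
# pipeline through a ModelManager, assuming the package exports ModelManager and
# HunyuanDiTImagePipeline at the top level as in DiffSynth-Studio. The checkpoint paths below
# are hypothetical placeholders; point them at the HunyuanDiT, CLIP/T5 text encoder, and
# SDXL VAE weights you actually have on disk.
#
#     import torch
#     from diffsynth import ModelManager, HunyuanDiTImagePipeline
#
#     model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
#     model_manager.load_models([
#         "models/HunyuanDiT/clip_text_encoder.bin",  # hypothetical path
#         "models/HunyuanDiT/mt5_text_encoder.bin",   # hypothetical path
#         "models/HunyuanDiT/hunyuan_dit.bin",        # hypothetical path
#         "models/HunyuanDiT/sdxl_vae.bin",           # hypothetical path
#     ])
#     pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
#     image = pipe(
#         prompt="a photo of a cat",
#         negative_prompt="blurry, low quality",
#         height=1024, width=1024,
#         num_inference_steps=30, cfg_scale=7.5,
#     )
#     image.save("image.png")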