import os

import torch
import torch.distributed as dist
from mmengine.runner import set_random_seed

from videogen_hub.pipelines.opensora.opensora.acceleration.parallel_states import set_sequence_parallel_group
from videogen_hub.pipelines.opensora.opensora.datasets import IMG_FPS, save_sample  # noqa: F401 (save_sample: saving currently disabled)
from videogen_hub.pipelines.opensora.opensora.models.text_encoder.t5 import text_preprocessing
from videogen_hub.pipelines.opensora.opensora.registry import MODELS, SCHEDULERS, build_module
from videogen_hub.pipelines.opensora.opensora.utils.config_utils import parse_configs
from videogen_hub.pipelines.opensora.opensora.utils.misc import to_torch_dtype

try:
    import colossalai
    from colossalai.cluster import DistCoordinator
except ImportError:
    colossalai = None


def _init_distributed():
    """Launch colossalai when WORLD_SIZE is set and colossalai is importable.

    Returns:
        ``(use_dist, coordinator, enable_sequence_parallelism)``. ``coordinator``
        is ``None`` when not running distributed.
    """
    if not os.environ.get("WORLD_SIZE", None) or colossalai is None:
        return False, None, False
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    enable_sequence_parallelism = coordinator.world_size > 1
    if enable_sequence_parallelism:
        # Use the whole world group for sequence parallelism.
        set_sequence_parallel_group(dist.group.WORLD)
    return True, coordinator, enable_sequence_parallelism


def _build_multi_resolution_args(cfg, device, dtype):
    """Build the extra model kwargs for ``cfg.multi_resolution``.

    Every tensor built here has ``cfg.batch_size`` rows along dim 0. For
    "STDiT2" with ``cfg.num_frames == 1``, ``cfg.fps`` is overwritten with
    ``IMG_FPS`` (this mutates ``cfg``). Any other mode yields an empty dict.
    """
    model_args = dict()
    batch_size = cfg.batch_size
    if cfg.multi_resolution == "PixArtMS":
        image_size = cfg.image_size
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
        ar = torch.tensor(
            [[image_size[0] / image_size[1]]], device=device, dtype=dtype
        ).repeat(batch_size, 1)
        model_args["data_info"] = dict(ar=ar, hw=hw)
    elif cfg.multi_resolution == "STDiT2":
        image_size = cfg.image_size

        def per_sample(value):
            # One scalar per batch element, shape (batch_size,).
            return torch.tensor([value], device=device, dtype=dtype).repeat(batch_size)

        height = per_sample(image_size[0])
        width = per_sample(image_size[1])
        num_frames = per_sample(cfg.num_frames)
        ar = per_sample(image_size[0] / image_size[1])
        if cfg.num_frames == 1:
            cfg.fps = IMG_FPS
        fps = per_sample(cfg.fps)
        model_args["height"] = height
        model_args["width"] = width
        model_args["num_frames"] = num_frames
        model_args["ar"] = ar
        model_args["fps"] = fps
    return model_args


def _slice_batch_args(model_args, batch_len):
    """Return a copy of *model_args* with every tensor (recursively) cut to its first *batch_len* rows."""
    sliced = {}
    for key, value in model_args.items():
        if isinstance(value, dict):
            sliced[key] = _slice_batch_args(value, batch_len)
        elif torch.is_tensor(value):
            sliced[key] = value[:batch_len]
        else:
            sliced[key] = value
    return sliced


def _outputs_exist(save_dir, sample_name, batch_prompts_raw, k, num_sample):
    """Return True if ``<save_dir>/<sample_name><prompt>[-k].mp4`` exists for every prompt in the batch."""
    for prompt in batch_prompts_raw:
        path = os.path.join(save_dir, f"{sample_name}{prompt}")
        if num_sample != 1:
            path = f"{path}-{k}"
        if not os.path.exists(f"{path}.mp4"):
            return False
    return True


def main(config=None):
    """Run text-to-video sampling for every prompt in the config.

    Args:
        config: Parsed config object. When ``None``, it is loaded with
            ``parse_configs(training=False)``.

    Returns:
        A list with one entry per prompt batch. Each entry is a list of the
        ``vae.decode`` outputs, one per repeat in ``range(cfg.num_sample)``.
        Repeats skipped by the ``prompt_as_path`` resume check are omitted.
    """
    # ======================================================
    # 1. cfg and init distributed env
    # ======================================================
    cfg = config
    if cfg is None:
        cfg = parse_configs(training=False)
    print(cfg)
    use_dist, coordinator, enable_sequence_parallelism = _init_distributed()

    # ======================================================
    # 2. runtime variables
    # ======================================================
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.dtype)
    set_random_seed(seed=cfg.seed)
    prompts = cfg.prompt

    # ======================================================
    # 3. build model & load weights
    # ======================================================
    # 3.1. build model
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    text_encoder.y_embedder = model.y_embedder  # hack for classifier-free guidance

    # 3.2. move to device & eval
    vae = vae.to(device, dtype).eval()
    model = model.to(device, dtype).eval()

    # 3.3. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 3.4. support for multi-resolution
    model_args = _build_multi_resolution_args(cfg, device, dtype)

    # ======================================================
    # 4. inference
    # ======================================================
    if cfg.sample_name is not None:
        sample_name = cfg.sample_name
    elif cfg.prompt_as_path:
        sample_name = ""
    else:
        sample_name = "sample"
    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # NOTE: writing samples to disk (save_sample) is disabled; the decoded
    # samples are returned to the caller instead.
    all_batch_samples = []
    # 4.1. batch generation
    for start in range(0, len(prompts), cfg.batch_size):
        # 4.2 sample in hidden space
        batch_prompts_raw = prompts[start : start + cfg.batch_size]
        batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts_raw]
        batch_len = len(batch_prompts_raw)

        # The last batch may be short: every per-sample conditioning tensor
        # (PixArtMS data_info as well as STDiT2 args) must shrink with it.
        if batch_len < cfg.batch_size:
            batch_args = _slice_batch_args(model_args, batch_len)
        else:
            batch_args = model_args

        all_samples = []
        # 4.3. diffusion sampling: multiple samples for each prompt
        for k in range(cfg.num_sample):
            # Skip if the outputs already exist (useful for resuming VBench runs).
            # NOTE(review): this function does not write those files itself;
            # they must come from an earlier run or another writer.
            if cfg.prompt_as_path and _outputs_exist(
                save_dir, sample_name, batch_prompts_raw, k, cfg.num_sample
            ):
                continue

            z = torch.randn(
                batch_len,
                vae.out_channels,
                *latent_size,
                device=device,
                dtype=dtype,
            )
            samples = scheduler.sample(
                model,
                text_encoder,
                z=z,
                prompts=batch_prompts,
                device=device,
                additional_args=batch_args,
            )
            latents = samples.to(dtype)
            # num_frames is only built for STDiT2; indexing it unconditionally
            # raised KeyError for every other multi_resolution mode.
            if "num_frames" in batch_args:
                samples = vae.decode(latents, batch_args["num_frames"])
            else:
                # assumes vae.decode accepts the latents alone — TODO confirm
                samples = vae.decode(latents)

            # 4.4. report samples (on the master rank only)
            if not use_dist or coordinator.is_master():
                for prompt in batch_prompts_raw[: len(samples)]:
                    print(f"Prompt: {prompt}")
            all_samples.append(samples)
        all_batch_samples.append(all_samples)
    return all_batch_samples


if __name__ == "__main__":
    main()