import argparse
import os

import imageio
import torch
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from transformers import T5EncoderModel, T5Tokenizer

from allegro.models.transformers.transformer_3d_allegro import AllegroTransformer3DModel
from allegro.models.vae.vae_allegro import AllegroAutoencoderKL3D
from allegro.pipelines.pipeline_allegro import AllegroPipeline


def single_inference(args):
    """Generate one 88-frame, 720x1280 video from a text prompt with Allegro.

    Loads the VAE, T5 text encoder/tokenizer, and DiT transformer from the
    paths in ``args``, assembles an :class:`AllegroPipeline` on ``cuda:0``,
    runs sampling, and writes the result to ``args.save_path`` as an mp4.

    Args:
        args: Parsed CLI namespace. Uses ``vae``, ``dit``, ``text_encoder``,
            ``tokenizer`` (checkpoint paths), ``user_prompt``,
            ``num_sampling_steps``, ``guidance_scale``, ``seed``,
            ``enable_cpu_offload``, and ``save_path``.

    Side effects:
        Allocates GPU memory, prints a notice when CPU offload is enabled,
        and writes the video file at ``args.save_path``.
    """
    dtype = torch.bfloat16

    # The VAE is deliberately kept in float32 — it performs better (numerically)
    # in full precision than in bfloat16, unlike the other components.
    vae = AllegroAutoencoderKL3D.from_pretrained(
        args.vae,
        torch_dtype=torch.float32,
    ).cuda()
    vae.eval()

    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder, torch_dtype=dtype)
    text_encoder.eval()

    tokenizer = T5Tokenizer.from_pretrained(args.tokenizer)

    scheduler = EulerAncestralDiscreteScheduler()

    transformer = AllegroTransformer3DModel.from_pretrained(
        args.dit,
        torch_dtype=dtype,
    ).cuda()
    transformer.eval()

    allegro_pipeline = AllegroPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer,
    ).to("cuda:0")

    # Template prompt: the user's text is substituted into the {} slot and
    # wrapped with quality tags. Kept byte-identical to the original.
    positive_prompt = """
(masterpiece), (best quality), (ultra-detailed), (unwatermarked),
{}
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo,
sharp focus, high budget, cinemascope, moody, epic, gorgeous
"""

    negative_prompt = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""

    user_prompt = positive_prompt.format(args.user_prompt.lower().strip())

    if args.enable_cpu_offload:
        # NOTE(review): the pipeline has already been moved to cuda:0 above,
        # so GPU memory is briefly fully allocated before offload kicks in —
        # confirm whether offload should instead replace the .to("cuda:0").
        allegro_pipeline.enable_sequential_cpu_offload()
        print("cpu offload enabled")

    out_video = allegro_pipeline(
        user_prompt,
        negative_prompt=negative_prompt,
        num_frames=88,
        height=720,
        width=1280,
        num_inference_steps=args.num_sampling_steps,
        guidance_scale=args.guidance_scale,
        max_sequence_length=512,
        generator=torch.Generator(device="cuda:0").manual_seed(args.seed),
    ).video[0]

    # imageio/ffmpeg quality scale: 10 is highest, 0 is lowest.
    imageio.mimwrite(args.save_path, out_video, fps=15, quality=8)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_prompt", type=str, default='')
    parser.add_argument("--vae", type=str, default='')
    parser.add_argument("--dit", type=str, default='')
    parser.add_argument("--text_encoder", type=str, default='')
    parser.add_argument("--tokenizer", type=str, default='')
    parser.add_argument("--save_path", type=str, default="./output_videos/test_video.mp4")
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--num_sampling_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--enable_cpu_offload", action='store_true')

    args = parser.parse_args()

    # Race-free replacement for the original exists()/makedirs() pair;
    # exist_ok also tolerates the directory appearing between check and create.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    single_inference(args)