File size: 3,330 Bytes
cdcfdd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import imageio
import os
import argparse
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from transformers import T5EncoderModel, T5Tokenizer
from allegro.pipelines.pipeline_allegro import AllegroPipeline
from allegro.models.vae.vae_allegro import AllegroAutoencoderKL3D
from allegro.models.transformers.transformer_3d_allegro import AllegroTransformer3DModel


def single_inference(args):
    """Generate a single video with the Allegro pipeline and write it to disk.

    Loads the VAE, T5 text encoder and tokenizer, and the Allegro 3D
    transformer from the paths given in ``args``. It wraps ``args.user_prompt``
    in a fixed quality-boosting template and samples an 88-frame 1280x720
    video, then saves it with ``imageio`` at 15 fps.

    Args:
        args: Parsed command-line namespace. Reads ``vae``, ``text_encoder``,
            ``tokenizer`` and ``dit`` (model paths), plus ``user_prompt``,
            ``enable_cpu_offload``, ``num_sampling_steps``, ``guidance_scale``,
            ``seed`` and ``save_path``.
    """
    dtype = torch.bfloat16
    device = "cuda:0"

    # The VAE performs better in float32, so it is deliberately not cast to
    # ``dtype`` like the other models.
    vae = AllegroAutoencoderKL3D.from_pretrained(args.vae, torch_dtype=torch.float32)
    vae.eval()

    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder, torch_dtype=dtype)
    text_encoder.eval()

    tokenizer = T5Tokenizer.from_pretrained(args.tokenizer)

    scheduler = EulerAncestralDiscreteScheduler()

    transformer = AllegroTransformer3DModel.from_pretrained(args.dit, torch_dtype=dtype)
    transformer.eval()

    allegro_pipeline = AllegroPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer,
    )

    if args.enable_cpu_offload:
        # Keep the weights on the CPU and let the offload hooks bring each
        # submodule to the GPU only when it runs. Moving the models to the
        # GPU first would allocate the full model in GPU memory, which is
        # exactly what offloading is meant to avoid.
        allegro_pipeline.enable_sequential_cpu_offload()
        print("cpu offload enabled")
    else:
        # Move the whole pipeline (the text encoder included, which is never
        # moved individually) to the GPU.
        allegro_pipeline = allegro_pipeline.to(device)

    positive_prompt = """
(masterpiece), (best quality), (ultra-detailed), (unwatermarked), 
{} 
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, 
sharp focus, high budget, cinemascope, moody, epic, gorgeous
"""

    negative_prompt = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, 
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""

    # Only the template's braces are interpreted by ``format``; braces inside
    # the user's text are inserted verbatim.
    user_prompt = positive_prompt.format(args.user_prompt.lower().strip())

    out_video = allegro_pipeline(
        user_prompt,
        negative_prompt=negative_prompt,
        num_frames=88,
        height=720,
        width=1280,
        num_inference_steps=args.num_sampling_steps,
        guidance_scale=args.guidance_scale,
        max_sequence_length=512,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    ).video[0]

    # imageio quality ranges from 0 (lowest) to 10 (highest).
    imageio.mimwrite(args.save_path, out_video, fps=15, quality=8)


def build_parser():
    """Return the argument parser for the single-video inference CLI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_prompt", type=str, default='',
                        help="text prompt describing the video")
    parser.add_argument("--vae", type=str, default='',
                        help="path to the Allegro VAE weights")
    parser.add_argument("--dit", type=str, default='',
                        help="path to the Allegro transformer (DiT) weights")
    parser.add_argument("--text_encoder", type=str, default='',
                        help="path to the T5 text encoder weights")
    parser.add_argument("--tokenizer", type=str, default='',
                        help="path to the T5 tokenizer")
    parser.add_argument("--save_path", type=str, default="./output_videos/test_video.mp4",
                        help="output video file")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="classifier-free guidance scale")
    parser.add_argument("--num_sampling_steps", type=int, default=100,
                        help="number of denoising steps")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for the sampling generator")
    parser.add_argument("--enable_cpu_offload", action='store_true',
                        help="enable sequential CPU offload to reduce GPU memory use")
    return parser


def ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one and it is missing."""
    parent = os.path.dirname(path)
    if parent:
        # exist_ok avoids the check-then-create race of an exists() test.
        os.makedirs(parent, exist_ok=True)


if __name__ == "__main__":
    args = build_parser().parse_args()
    ensure_parent_dir(args.save_path)
    single_inference(args)