import gradio as gr

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video
import tqdm
from torchvision.transforms import ToPILImage
import os
import spaces

#from torchao.quantization import autoquant

device="cuda"
shape=(1,48//4,16,256//8,256//8)
sample_N=25
torch_dtype=torch.bfloat16
eps=1
cfg=2.5

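# Text encoder: llm-jp-3-1.8b; its last hidden states are used as the prompt conditioning for the transformer.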
tokenizer = AutoTokenizer.from_pretrained(
    "llm-jp/llm-jp-3-1.8b"
)

text_encoder = AutoModelForCausalLM.from_pretrained(
    "llm-jp/llm-jp-3-1.8b",
    torch_dtype=torch_dtype
)
text_encoder=text_encoder.to(device)

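# Video transformer (AIdeaLab VideoJP, a rectified-flow model), loaded with the access token from the TOKEN environment variable.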
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "aidealab/AIdeaLab-VideoJP",
    torch_dtype=torch_dtype,
    token=os.environ['TOKEN']
)
#transformer = autoquant(transformer, error_on_unseen=False)
#transformer.to(memory_format=torch.channels_last)
#transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer=transformer.to(device)

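# CogVideoX VAE decodes latents back to video frames; slicing and tiling reduce peak VRAM during decoding.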
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b",
    subfolder="vae"
)
vae=vae.to(dtype=torch_dtype, device=device)
vae.enable_slicing()
vae.enable_tiling()

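# @spaces.GPU requests a GPU for the duration of this call when running on Hugging Face Spaces (ZeroGPU).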
@spaces.GPU
def text_to_video(prompt, cfg=cfg):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True, attention_mask=text_inputs.attention_mask.to(device)).hidden_states[-1]
    prompt_embeds = prompt_embeds.to(dtype=torch_dtype, device=device)

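    # Embeddings of the empty prompt, used as the unconditional branch for classifier-free guidance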
    null_text_inputs = tokenizer(
        "",
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    null_text_input_ids = null_text_inputs.input_ids
    null_prompt_embeds = text_encoder(null_text_input_ids.to(device), output_hidden_states=True, attention_mask=null_text_inputs.attention_mask.to(device)).hidden_states[-1]
    null_prompt_embeds = null_prompt_embeds.to(dtype=torch_dtype, device=device)

    # Euler discrete sampler with classifier-free guidance
    z0 = torch.randn(shape, device=device)
    latents = z0.detach().clone().to(torch_dtype)
    
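    # Integrate the flow ODE from t=0 to t=1 with fixed Euler steps of size dt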
    dt = 1.0 / sample_N
    with torch.no_grad():
        for i in tqdm.tqdm(range(sample_N)):
            num_t = i / sample_N
            t = torch.ones(shape[0], device=device) * num_t
            # Map flow time t in [0, 1) to the model's timestep scale: 1000 at t=0 down to eps as t approaches 1
            pseudo_t = (1000 - eps) * (1 - t) + eps
            positive_conditional = transformer(hidden_states=latents, timestep=pseudo_t, encoder_hidden_states=prompt_embeds, image_rotary_emb=None)
            null_conditional = transformer(hidden_states=latents, timestep=pseudo_t, encoder_hidden_states=null_prompt_embeds, image_rotary_emb=None)
            # Classifier-free guidance: push the unconditional prediction toward the conditional one
            pred = null_conditional.sample + cfg * (positive_conditional.sample - null_conditional.sample)
            # Euler step along the predicted velocity field
            latents = latents.detach().clone() + dt * pred.detach().clone()
    
        # Decode the latents back to pixel space with the VAE
        latents = latents / vae.config.scaling_factor
        latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W] -> [B, C, F, H, W] expected by the VAE
        x = vae.decode(latents).sample
        x = x / 2 + 0.5  # rescale from [-1, 1] to [0, 1]
        x = x.clamp(0, 1)
        x = x.permute(0, 2, 1, 3, 4).to(torch.float32)  # back to [B, F, C, H, W]
        print(x.shape)
        x = [ToPILImage()(frame) for frame in x[0]]  # convert each frame to a PIL image
    
    export_to_video(x, "output.mp4", fps=24)
    return "output.mp4"

css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Define the Gradio application layout
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# AIdeaLab VideoJP Demo
AIdeaLab VideoJPは、Rectified Flow Transformerで作られている軽量な動画生成モデルです ([詳細](https://note.com/aidealab/n/n677018ea1953)、[モデル](https://huggingface.co/aidealab/AIdeaLab-VideoJP))。十数秒で動画を作ることができます。なお、AIdeaLab VideoJPは経済産業省と国立研究開発法人新エネルギー・産業技術総合開発機構(NEDO)が実施する、国内の生成AIの開発力強化を目的としたプロジェクト「GENIAC(Generative AI Accelerator Challenge)」の成果をもとに作成されました。""")        
        # Free-form prompt input
        text_input = gr.Textbox(
            label="動画生成のプロンプトを入力してください",
            placeholder="例: 静かな森の中を、やわらかな朝陽が差し込む。木漏れ日に照らされた小川には小さな魚が泳ぎ、森の奥からは小鳥のさえずりが聞こえる。",
            lines=5
        )
        generate_button = gr.Button("生成")

        output_video = gr.Video(label="生成された動画")

        # Wire the button click to video generation
        generate_button.click(
            fn=text_to_video,
            inputs=text_input,
            outputs=output_video
        )

demo.launch()