Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from diffusers import ( | |
AutoencoderKLCogVideoX, | |
CogVideoXTransformer3DModel, | |
) | |
from diffusers.utils import export_to_video | |
import tqdm | |
from torchvision.transforms import ToPILImage | |
import os | |
import spaces | |
#from torchao.quantization import autoquant | |
# ---------------------------------------------------------------------------
# Inference configuration (read by text_to_video).
# ---------------------------------------------------------------------------
# Device every model and tensor is moved to.
device = "cuda"
# Shape of the initial noise / latent tensor handed to the transformer.
# NOTE(review): presumably (batch, latent frames, channels, height/8, width/8)
# for a 48-frame 256x256 clip - confirm against the model card.
shape = (1, 48 // 4, 16, 256 // 8, 256 // 8)
# Number of Euler integration steps used by the sampler.
sample_N = 25
# Dtype used for all model weights and activations.
torch_dtype = torch.bfloat16
# Lower bound of the timestep value passed to the transformer
# (the sampler maps t in [0, 1) to (1000 - eps) * (1 - t) + eps).
eps = 1
# Default classifier-free guidance weight.
cfg = 2.5

# ---------------------------------------------------------------------------
# Text encoder: the last hidden state of this causal LM is used as the
# prompt embedding.
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    "llm-jp/llm-jp-3-1.8b"
)
text_encoder = AutoModelForCausalLM.from_pretrained(
    "llm-jp/llm-jp-3-1.8b",
    torch_dtype=torch_dtype,
)
text_encoder = text_encoder.to(device)

# ---------------------------------------------------------------------------
# Video transformer.
# ---------------------------------------------------------------------------
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "aidealab/AIdeaLab-VideoJP",
    torch_dtype=torch_dtype,
    # Use .get() so a missing TOKEN variable does not abort start-up with a
    # bare KeyError; token=None lets the Hub client fall back to its default
    # credentials (cached login or anonymous access).
    token=os.environ.get("TOKEN"),
)
# Disabled experiments, kept for reference:
#transformer = autoquant(transformer, error_on_unseen=False)
#transformer.to(memory_format=torch.channels_last)
#transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer = transformer.to(device)

# ---------------------------------------------------------------------------
# VAE used to decode latents back to frames.
# ---------------------------------------------------------------------------
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b",
    subfolder="vae",
)
vae = vae.to(dtype=torch_dtype, device=device)
# Slicing and tiling trade decode speed for lower peak memory.
vae.enable_slicing()
vae.enable_tiling()
def _encode_prompt(text):
    """Return the text encoder's final hidden states for *text*.

    The text is padded/truncated to 512 tokens, run through the module-level
    ``text_encoder`` and the last entry of ``hidden_states`` is returned, cast
    to ``torch_dtype`` on ``device``.
    """
    tokens = tokenizer(
        text,
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    encoded = text_encoder(
        tokens.input_ids.to(device),
        output_hidden_states=True,
        attention_mask=tokens.attention_mask.to(device),
    )
    return encoded.hidden_states[-1].to(dtype=torch_dtype, device=device)


# NOTE(review): `spaces` is imported at the top of the file but `spaces.GPU`
# is never applied. If this Space runs on ZeroGPU hardware, this function
# needs the `@spaces.GPU` decorator - confirm the hardware tier.
def text_to_video(prompt, cfg=cfg):
    """Generate a short video from a text prompt and return its file path.

    Args:
        prompt: Text describing the video to generate.
        cfg: Classifier-free guidance weight. The prediction used at each
            step is ``uncond + cfg * (cond - uncond)``.

    Returns:
        The string ``"output.mp4"``; the video is written to that path in the
        current working directory at 24 fps, overwriting any previous file.
    """
    # Everything here is inference only. Running the text encoder outside
    # no_grad would keep autograd state for a 1.8B-parameter model alive.
    with torch.no_grad():
        prompt_embeds = _encode_prompt(prompt)
        # The empty prompt provides the unconditional branch for guidance.
        null_prompt_embeds = _encode_prompt("")

        # Discrete Euler sampler with classifier-free guidance.
        noise = torch.randn(shape, device=device)
        latents = noise.to(torch_dtype)
        dt = 1.0 / sample_N
        for step in tqdm.tqdm(range(sample_N)):
            t = torch.ones(shape[0], device=device) * (step / sample_N)
            # Map t in [0, 1) to the transformer's timestep scale:
            # t = 0 -> 1000, approaching eps as t -> 1.
            pseudo_t = (1000 - eps) * (1 - t) + eps
            cond = transformer(
                hidden_states=latents,
                timestep=pseudo_t,
                encoder_hidden_states=prompt_embeds,
                image_rotary_emb=None,
            ).sample
            uncond = transformer(
                hidden_states=latents,
                timestep=pseudo_t,
                encoder_hidden_states=null_prompt_embeds,
                image_rotary_emb=None,
            ).sample
            pred = uncond + cfg * (cond - uncond)
            latents = latents + dt * pred

        # Undo the VAE latent scaling, then swap dims 1 and 2 for the decoder.
        # NOTE(review): presumably [B, F, C, H, W] -> [B, C, F, H, W] - confirm.
        latents = latents / vae.config.scaling_factor
        latents = latents.permute(0, 2, 1, 3, 4)
        video = vae.decode(latents).sample
        # Map decoder output from [-1, 1] to [0, 1].
        video = (video / 2 + 0.5).clamp(0, 1)
        # Back to frame-major order so video[0] iterates over frames:
        # [B, F, C, H, W].
        video = video.permute(0, 2, 1, 3, 4).to(torch.float32)
        print(video.shape)

        to_pil = ToPILImage()
        frames = [to_pil(frame) for frame in video[0]]
        export_to_video(frames, "output.mp4", fps=24)
    return "output.mp4"
# Stylesheet for the page: centres the single column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""

# User-facing texts, kept together so the layout code below stays readable.
_INTRO_MARKDOWN = """# AIdeaLab VideoJP Demo
AIdeaLab VideoJPは、Rectified Flow Transformerで作られている軽量な動画生成モデルです ([詳細](https://note.com/aidealab/n/n677018ea1953)、[モデル](https://huggingface.co/aidealab/AIdeaLab-VideoJP))。十数秒で動画を作ることができます。なお、AIdeaLab VideoJPは経済産業省と国立研究開発法人新エネルギー・産業技術総合開発機構(NEDO)が実施する、国内の生成AIの開発力強化を目的としたプロジェクト「GENIAC(Generative AI Accelerator Challenge)」の成果をもとに作成されました。"""
_PROMPT_LABEL = "動画生成のプロンプトを入力してください"
_PROMPT_PLACEHOLDER = "例: 静かな森の中を、やわらかな朝陽が差し込む。木漏れ日に照らされた小川には小さな魚が泳ぎ、森の奥からは小鳥のさえずりが聞こえる。"

# Define the layout of the Gradio application.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_INTRO_MARKDOWN)

        # Free-form prompt entry.
        text_input = gr.Textbox(
            label=_PROMPT_LABEL,
            placeholder=_PROMPT_PLACEHOLDER,
            lines=5,
        )
        generate_button = gr.Button("生成")
        output_video = gr.Video(label="生成された動画")

        # Clicking the button runs generation on the prompt and shows the
        # resulting video.
        generate_button.click(
            text_to_video,
            inputs=text_input,
            outputs=output_video,
        )

demo.launch()