schroneko's picture
enable to change prompt
52ce992
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video
import tqdm
from torchvision.transforms import ToPILImage
import os
import spaces
#from torchao.quantization import autoquant
# Runtime configuration shared by the model-loading code and text_to_video.
device = "cuda"
torch_dtype = torch.bfloat16  # dtype used for all models and the sampled latents

# Latent tensor shape for one sample, laid out (batch, frames, channels,
# height, width) as the sampler feeds it to the transformer. The divisors
# presumably reflect the VAE's 4x temporal / 8x spatial compression of a
# 48-frame, 256x256 video -- TODO confirm against the model card.
shape = (1, 48 // 4, 16, 256 // 8, 256 // 8)

sample_N = 25  # number of Euler sampling steps
eps = 1  # model timestep approached as the sampling time t -> 1
cfg = 2.5  # default classifier-free guidance scale
# Text encoder: the llm-jp-3 1.8B causal LM. text_to_video uses its last
# hidden layer as the prompt embedding.
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-1.8b")
text_encoder = AutoModelForCausalLM.from_pretrained(
    "llm-jp/llm-jp-3-1.8b",
    torch_dtype=torch_dtype,
).to(device)

# Video transformer. The access token is read from the TOKEN environment
# variable (a KeyError is raised at import time if it is unset).
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "aidealab/AIdeaLab-VideoJP",
    torch_dtype=torch_dtype,
    token=os.environ["TOKEN"],
)
# Disabled optimisation experiments, kept for reference:
#   transformer = autoquant(transformer, error_on_unseen=False)
#   transformer.to(memory_format=torch.channels_last)
#   transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer = transformer.to(device)

# VAE borrowed from CogVideoX-2b: loaded at its default dtype, then cast to
# torch_dtype and moved to the device.
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")
vae = vae.to(dtype=torch_dtype, device=device)
# Decode in slices/tiles to reduce memory use (see diffusers AutoencoderKL docs).
vae.enable_slicing()
vae.enable_tiling()
def _encode_text(text):
    """Return the text encoder's last hidden layer for `text`.

    The text is padded/truncated to exactly 512 tokens, so the prompt and the
    empty (unconditional) prompt produce embeddings of the same shape. The
    result is cast to `torch_dtype` on `device`. Call under `torch.no_grad()`
    so no autograd graph is recorded for the encoder forward pass.
    """
    text_inputs = tokenizer(
        text,
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    outputs = text_encoder(
        text_inputs.input_ids.to(device),
        output_hidden_states=True,
        attention_mask=text_inputs.attention_mask.to(device),
    )
    return outputs.hidden_states[-1].to(dtype=torch_dtype, device=device)


def _sample_latents(prompt_embeds, null_prompt_embeds, guidance_scale):
    """Run the Euler sampler with classifier-free guidance; return final latents.

    Starts from Gaussian noise of size `shape` and takes `sample_N` equal
    steps of size 1 / sample_N along the guided transformer prediction.
    Call under `torch.no_grad()`.
    """
    latents = torch.randn(shape, device=device).to(torch_dtype)
    dt = 1.0 / sample_N
    for i in tqdm.tqdm(range(sample_N)):
        t = torch.ones(shape[0], device=device) * (i / sample_N)
        # Map sampling time t in [0, 1) to the model timestep:
        # t = 0 -> 1000, decreasing linearly towards eps as t -> 1.
        model_timestep = (1000 - eps) * (1 - t) + eps
        cond = transformer(
            hidden_states=latents,
            timestep=model_timestep,
            encoder_hidden_states=prompt_embeds,
            image_rotary_emb=None,
        ).sample
        uncond = transformer(
            hidden_states=latents,
            timestep=model_timestep,
            encoder_hidden_states=null_prompt_embeds,
            image_rotary_emb=None,
        ).sample
        pred = uncond + guidance_scale * (cond - uncond)
        latents = latents + dt * pred
    return latents


def _decode_to_frames(latents):
    """Decode sampler latents into a list of PIL frames (first batch item only).

    The sampler's latents are laid out [B, F, C, H, W]; the VAE works on
    [B, C, F, H, W], so the channel and frame axes are swapped before
    decoding and swapped back afterwards.
    """
    latents = latents / vae.config.scaling_factor
    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W] -> [B, C, F, H, W]
    video = vae.decode(latents).sample
    # Presumably the VAE outputs values in [-1, 1]; map to [0, 1] and clip.
    video = (video / 2 + 0.5).clamp(0, 1)
    video = video.permute(0, 2, 1, 3, 4).to(torch.float32)  # [B, C, F, H, W] -> [B, F, C, H, W]
    to_pil = ToPILImage()
    return [to_pil(frame) for frame in video[0]]


@spaces.GPU
def text_to_video(prompt, cfg=cfg):
    """Generate a short video from a text prompt.

    Args:
        prompt: Text prompt describing the video.
        cfg: Classifier-free guidance scale; defaults to the module-level `cfg`.

    Returns:
        Path of the MP4 file (24 fps) written by `export_to_video`. A fresh
        temporary file is used on every call, so concurrent requests cannot
        overwrite each other's output.
    """
    # Pure inference: without no_grad the text-encoder (and possibly VAE)
    # forward passes would retain activations for a backward pass never taken.
    with torch.no_grad():
        prompt_embeds = _encode_text(prompt)
        null_prompt_embeds = _encode_text("")
        latents = _sample_latents(prompt_embeds, null_prompt_embeds, cfg)
        # Release the text embeddings before decoding to free GPU memory.
        del prompt_embeds, null_prompt_embeds
        frames = _decode_to_frames(latents)
    # With no output path given, export_to_video writes to a new temporary
    # .mp4 file and returns its path.
    return export_to_video(frames, fps=24)
# Page CSS: center the main column (elem_id="col-container") and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 520px;\n"
    "}\n"
)
# User-facing text for the demo page (kept verbatim; it is runtime content).
_DESCRIPTION = (
    "# AIdeaLab VideoJP Demo\n"
    "AIdeaLab VideoJPは、Rectified Flow Transformerで作られている軽量な動画生成モデルです ([詳細](https://note.com/aidealab/n/n677018ea1953)、[モデル](https://huggingface.co/aidealab/AIdeaLab-VideoJP))。"
    "十数秒で動画を作ることができます。"
    "なお、AIdeaLab VideoJPは経済産業省と国立研究開発法人新エネルギー・産業技術総合開発機構(NEDO)が実施する、国内の生成AIの開発力強化を目的としたプロジェクト「GENIAC(Generative AI Accelerator Challenge)」の成果をもとに作成されました。"
)
_PLACEHOLDER = "例: 静かな森の中を、やわらかな朝陽が差し込む。木漏れ日に照らされた小川には小さな魚が泳ぎ、森の奥からは小鳥のさえずりが聞こえる。"

# Define the Gradio app layout: one centered column holding the description,
# a free-form prompt box, a generate button and the resulting video.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_DESCRIPTION)
        # Free-form text input for the prompt.
        text_input = gr.Textbox(
            label="動画生成のプロンプトを入力してください",
            placeholder=_PLACEHOLDER,
            lines=5,
        )
        generate_button = gr.Button("生成")
        output_video = gr.Video(label="生成された動画")
        # On click: run text_to_video on the prompt and show the returned video file.
        generate_button.click(text_to_video, inputs=text_input, outputs=output_video)

demo.launch()