# Open-Sora-Plan text-to-video Gradio demo (Hugging Face Space app).
# NOTE(review): the original header here was page-scrape residue
# ("Spaces: / Runtime error / Runtime error"), not code.
# Standard library
import argparse
import os
import random
import sys
from datetime import datetime
from typing import List, Union

# Third-party
import gradio as gr
import imageio
import numpy as np
import torch
from diffusers import PNDMScheduler
from diffusers.models import AutoencoderKL
from gradio.components import Textbox, Video, Image
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
from transformers import T5Tokenizer, T5EncoderModel

# Project-local
from opensora.models.ae import ae_stride_config, getae, getae_wrapper
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.sample.pipeline_videogen import VideoGenPipeline
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
    """Run the text-to-video pipeline for one prompt and save the result as mp4.

    Relies on module-level globals created in ``__main__``:
    ``args``, ``transformer_model``, and ``videogen_pipeline``.

    Args:
        prompt: Text prompt to condition generation on.
        sample_steps: Number of diffusion inference steps.
        scale: Classifier-free guidance scale.
        seed: RNG seed; ignored when ``randomize_seed`` is True.
        randomize_seed: If True, draw a fresh random seed instead of ``seed``.
        force_images: If True, generate a single-frame video (i.e. an image).

    Returns:
        Tuple of (path to saved mp4, echoed prompt, model-info string, seed used).
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)

    # args.version encodes "frames x height x width", e.g. "65x512x512";
    # split once instead of three times.
    frames_str, height_str, width_str = args.version.split('x')
    height, width = int(height_str), int(width_str)
    video_length = transformer_model.config.video_length if not force_images else 1
    num_frames = 1 if video_length == 1 else int(frames_str)

    videos = videogen_pipeline(
        prompt,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=sample_steps,
        guidance_scale=scale,
        enable_temporal_attentions=not force_images,  # temporal attention is pointless for 1 frame
        num_images_per_prompt=1,
        mask_feature=True,
    ).video
    torch.cuda.empty_cache()

    # NOTE(review): a fixed output path means concurrent Gradio requests
    # overwrite each other's video — consider tempfile.NamedTemporaryFile.
    tmp_save_path = 'tmp.mp4'
    # imageio quality range: 0 (worst) .. 10 (best).
    imageio.mimwrite(tmp_save_path, videos[0], fps=24, quality=9)

    display_model_info = (
        f"Video size: {num_frames}×{height}×{width}, "
        f"\nSampling Step: {sample_steps}, "
        f"\nGuidance Scale: {scale}"
    )
    return tmp_save_path, prompt, display_model_info, seed
if __name__ == '__main__':
    # Static configuration standing in for a CLI. argparse.Namespace gives
    # plain attribute access without the throwaway ``type('args', (), {...})``
    # class trick; reads (args.ae, args.version, ...) behave identically.
    args = argparse.Namespace(
        ae='CausalVAEModel_4x8x8',
        force_images=False,
        model_path='LanguageBind/Open-Sora-Plan-v1.0.0',
        text_encoder_name='DeepFloyd/t5-v1_1-xxl',
        version='65x512x512',
    )
    device = torch.device('cuda:0')

    # Load the diffusion transformer and the causal VAE in fp16.
    transformer_model = LatteT2V.from_pretrained(
        args.model_path, subfolder=args.version,
        torch_dtype=torch.float16, cache_dir='cache_dir').to(device)
    vae = getae_wrapper(args.ae)(
        args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
    vae.vae.enable_tiling()  # tiled decode bounds peak GPU memory

    # Latent spatial size = image size divided by the AE's spatial strides.
    image_size = int(args.version.split('x')[1])
    latent_size = (image_size // ae_stride_config[args.ae][1],
                   image_size // ae_stride_config[args.ae][2])
    vae.latent_size = latent_size
    transformer_model.force_images = args.force_images

    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
    text_encoder = T5EncoderModel.from_pretrained(
        args.text_encoder_name, cache_dir="cache_dir",
        torch_dtype=torch.float16).to(device)

    # Inference only: put every module in eval mode.
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = PNDMScheduler()
    videogen_pipeline = VideoGenPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device=device)

    # Gradio UI: inputs/outputs mirror generate_img's signature and return tuple.
    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            Textbox(label="", placeholder="Please enter your prompt. \n"),
            gr.Slider(label='Sample Steps', minimum=1, maximum=500, value=50, step=10),
            gr.Slider(label='Guidance Scale', minimum=0.1, maximum=30.0, value=10.0, step=0.1),
            gr.Slider(label="Seed", minimum=0, maximum=203279, step=1, value=0),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Checkbox(label="Generate image (1 frame video)", value=False),
        ],
        outputs=[
            Video(label="Vid", width=512, height=512),
            Textbox(label="input prompt"),
            Textbox(label="model info"),
            gr.Slider(label='seed'),
        ],
        title=title_markdown,
        description=DESCRIPTION,
        theme=gr.themes.Default(),
        css=block_css,
        examples=examples,
    )
    demo.launch()