|
import json |
|
import os |
|
import tempfile |
|
import time |
|
|
|
import click |
|
import numpy as np |
|
|
|
from einops import rearrange |
|
from PIL import Image |
|
from tqdm import tqdm |
|
|
|
from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel |
|
|
|
model = None |
|
model_path = "weights" |
|
def noexcept(f): |
|
try: |
|
return f() |
|
except: |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_model_path(path): |
|
global model_path |
|
model_path = path |
|
|
|
|
|
def load_model(): |
|
global model, model_path |
|
if model is None: |
|
|
|
MOCHI_DIR = model_path |
|
VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors" |
|
MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml" |
|
MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors" |
|
|
|
model = T2VSynthMochiModel( |
|
device_id=0, |
|
world_size=1, |
|
local_rank=0, |
|
vae_stats_path=f"{MOCHI_DIR}/vae_stats.json", |
|
vae_checkpoint_path=VAE_CHECKPOINT_PATH, |
|
dit_config_path=MODEL_CONFIG_PATH, |
|
dit_checkpoint_path=MODEL_CHECKPOINT_PATH, |
|
) |
|
|
|
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): |
|
if linear_steps is None: |
|
linear_steps = num_steps // 2 |
|
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] |
|
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps |
|
quadratic_steps = num_steps - linear_steps |
|
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) |
|
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2) |
|
const = quadratic_coef * (linear_steps ** 2) |
|
quadratic_sigma_schedule = [ |
|
quadratic_coef * (i ** 2) + linear_coef * i + const |
|
for i in range(linear_steps, num_steps) |
|
] |
|
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] |
|
sigma_schedule = [1.0 - x for x in sigma_schedule] |
|
return sigma_schedule |
|
|
|
def generate_video( |
|
prompt, |
|
negative_prompt, |
|
width, |
|
height, |
|
num_frames, |
|
seed, |
|
cfg_scale, |
|
num_inference_steps, |
|
): |
|
load_model() |
|
|
|
|
|
|
|
sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025) |
|
|
|
|
|
|
|
|
|
|
|
cfg_schedule = [cfg_scale] * num_inference_steps |
|
|
|
args = { |
|
"height": height, |
|
"width": width, |
|
"num_frames": num_frames, |
|
"mochi_args": { |
|
"sigma_schedule": sigma_schedule, |
|
"cfg_schedule": cfg_schedule, |
|
"num_inference_steps": num_inference_steps, |
|
"batch_cfg": False, |
|
}, |
|
"prompt": [prompt], |
|
"negative_prompt": [negative_prompt], |
|
"seed": seed, |
|
} |
|
|
|
final_frames = None |
|
for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1): |
|
final_frames = frames |
|
|
|
assert isinstance(final_frames, np.ndarray) |
|
assert final_frames.dtype == np.float32 |
|
|
|
final_frames = rearrange(final_frames, "t b h w c -> b t h w c") |
|
final_frames = final_frames[0] |
|
|
|
os.makedirs("outputs", exist_ok=True) |
|
output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4") |
|
|
|
with tempfile.TemporaryDirectory() as tmpdir: |
|
frame_paths = [] |
|
for i, frame in enumerate(final_frames): |
|
frame = (frame * 255).astype(np.uint8) |
|
frame_img = Image.fromarray(frame) |
|
frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png") |
|
frame_img.save(frame_path) |
|
frame_paths.append(frame_path) |
|
|
|
frame_pattern = os.path.join(tmpdir, "frame_%04d.png") |
|
ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}" |
|
os.system(ffmpeg_cmd) |
|
|
|
json_path = os.path.splitext(output_path)[0] + ".json" |
|
with open(json_path, "w") as f: |
|
json.dump(args, f, indent=4) |
|
|
|
return output_path |
|
|
|
|
|
@click.command() |
|
@click.option("--prompt", default=""" |
|
a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes, |
|
rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and |
|
diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider |
|
resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body, |
|
with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive, |
|
with the wind subtly blowing sand grains across the frame.""" |
|
, required=False, help="Prompt for video generation.") |
|
@click.option( |
|
"--negative_prompt", default="", help="Negative prompt for video generation." |
|
) |
|
@click.option("--width", default=848, type=int, help="Width of the video.") |
|
@click.option("--height", default=480, type=int, help="Height of the video.") |
|
@click.option("--num_frames", default=163, type=int, help="Number of frames.") |
|
@click.option("--seed", default=12345, type=int, help="Random seed.") |
|
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.") |
|
@click.option( |
|
"--num_steps", default=64, type=int, help="Number of inference steps." |
|
) |
|
@click.option("--model_dir", required=True, help="Path to the model directory.") |
|
def generate_cli( |
|
prompt, |
|
negative_prompt, |
|
width, |
|
height, |
|
num_frames, |
|
seed, |
|
cfg_scale, |
|
num_steps, |
|
model_dir, |
|
): |
|
set_model_path(model_dir) |
|
output = generate_video( |
|
prompt, |
|
negative_prompt, |
|
width, |
|
height, |
|
num_frames, |
|
seed, |
|
cfg_scale, |
|
num_steps, |
|
) |
|
click.echo(f"Video generated at: {output}") |
|
|
|
if __name__ == "__main__": |
|
generate_cli() |
|
|