import json
import os
import tempfile
import time
import click
import numpy as np
# import ray  # only needed for the multi-GPU MochiWrapper kept below for reference
from einops import rearrange
from PIL import Image
from tqdm import tqdm
from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
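
# Lazily-initialized global model handle and default directory containing the Mochi weights.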
model = None
model_path = "weights"
def noexcept(f):
    # Call f and swallow any exception; used to drain errors from secondary Ray workers in MochiWrapper below.
    try:
        return f()
    except Exception:
        pass
# class MochiWrapper:
# def __init__(self, *, num_workers, **actor_kwargs):
# super().__init__()
# RemoteClass = ray.remote(T2VSynthMochiModel)
# self.workers = [
# RemoteClass.options(num_gpus=1).remote(
# device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs
# )
# for i in range(num_workers)
# ]
# # Ensure the __init__ method has finished on all workers
# for worker in self.workers:
# ray.get(worker.__ray_ready__.remote())
# self.is_loaded = True
# def __call__(self, args):
# work_refs = [
# worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers)
# ]
# try:
# for result in work_refs[0]:
# yield ray.get(result)
# # Handle the (very unlikely) edge-case where a worker that's not the 1st one
# # fails (don't want an uncaught error)
# for result in work_refs[1:]:
# ray.get(result)
# except Exception as e:
# # Get exception from other workers
# for ref in work_refs[1:]:
# noexcept(lambda: ray.get(ref))
# raise e
def set_model_path(path):
global model_path
model_path = path
def load_model():
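    # Lazily construct the single-GPU Mochi pipeline the first time a video is generated.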
global model, model_path
if model is None:
#ray.init()
MOCHI_DIR = model_path
VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors"
MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml"
MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors"
model = T2VSynthMochiModel(
device_id=0,
world_size=1,
local_rank=0,
vae_stats_path=f"{MOCHI_DIR}/vae_stats.json",
vae_checkpoint_path=VAE_CHECKPOINT_PATH,
dit_config_path=MODEL_CONFIG_PATH,
dit_checkpoint_path=MODEL_CHECKPOINT_PATH,
)
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
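    # Sigmas increase linearly up to threshold_noise, then quadratically up to 1.0;
    # each value is then mapped to (1 - sigma), so the returned schedule runs from
    # 1.0 down to 0.0 with length num_steps + 1.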
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
const = quadratic_coef * (linear_steps ** 2)
quadratic_sigma_schedule = [
quadratic_coef * (i ** 2) + linear_coef * i + const
for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
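# For example, linear_quadratic_schedule(64, 0.025) returns 65 sigmas starting at 1.0 and ending at 0.0.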
def generate_video(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_inference_steps,
):
load_model()
    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
    # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0, decreasing monotonically in between.
sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
# cfg_schedule should be a list of floats of length num_inference_steps.
# For simplicity, we just use the same cfg scale at all timesteps,
# but more optimal schedules may use varying cfg, e.g:
# [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
cfg_schedule = [cfg_scale] * num_inference_steps
args = {
"height": height,
"width": width,
"num_frames": num_frames,
"mochi_args": {
"sigma_schedule": sigma_schedule,
"cfg_schedule": cfg_schedule,
"num_inference_steps": num_inference_steps,
"batch_cfg": False,
},
"prompt": [prompt],
"negative_prompt": [negative_prompt],
"seed": seed,
}
final_frames = None
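    # Stream intermediate results from the sampler; only the frames from the final step are kept.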
for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1):
final_frames = frames
assert isinstance(final_frames, np.ndarray)
assert final_frames.dtype == np.float32
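    # Frames come back as (time, batch, height, width, channels); move batch to the
    # front and keep the single generated sample.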
final_frames = rearrange(final_frames, "t b h w c -> b t h w c")
final_frames = final_frames[0]
os.makedirs("outputs", exist_ok=True)
output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")
with tempfile.TemporaryDirectory() as tmpdir:
frame_paths = []
for i, frame in enumerate(final_frames):
frame = (frame * 255).astype(np.uint8)
frame_img = Image.fromarray(frame)
frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png")
frame_img.save(frame_path)
frame_paths.append(frame_path)
frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
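        # Stitch the PNG frames into an H.264 MP4 at 30 fps; assumes ffmpeg is available on PATH.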
ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}"
os.system(ffmpeg_cmd)
json_path = os.path.splitext(output_path)[0] + ".json"
with open(json_path, "w") as f:
json.dump(args, f, indent=4)
return output_path
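
# Example invocation (the script filename below is illustrative; use this module's actual name):
#   python cli.py --model_dir weights --num_steps 64 --prompt "A drone flyover of desert dunes"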
@click.command()
@click.option("--prompt", default="""
a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes,
rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and
diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider
resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body,
with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive,
with the wind subtly blowing sand grains across the frame.""",
    required=False, help="Prompt for video generation.")
@click.option(
"--negative_prompt", default="", help="Negative prompt for video generation."
)
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=12345, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
@click.option(
"--num_steps", default=64, type=int, help="Number of inference steps."
)
@click.option("--model_dir", required=True, help="Path to the model directory.")
def generate_cli(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_steps,
model_dir,
):
set_model_path(model_dir)
output = generate_video(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_steps,
)
click.echo(f"Video generated at: {output}")
if __name__ == "__main__":
generate_cli()