|
from dataclasses import dataclass |
|
from pathlib import Path |
|
import pathlib |
|
from typing import Dict, Any, Optional, Tuple |
|
import asyncio |
|
import base64 |
|
import io |
|
import pprint |
|
import logging |
|
import random |
|
import traceback |
|
import os |
|
import numpy as np |
|
import torch |
|
from diffusers import LTXPipeline, LTXImageToVideoPipeline |
|
from PIL import Image |
|
|
|
from varnish import Varnish |
|
|
|
|
|
# Configure root logging at import time so this module's INFO-level messages
# are actually emitted, then grab the module logger used throughout the file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolution limits: no side may exceed MAX_LARGE_SIDE, and the pixel area is
# budgeted against MAX_SMALL_SIDE * MAX_LARGE_SIDE (see validate_and_adjust).
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768

# Frame counts take the form 8*k + 1; k is capped at 21, i.e. 169 frames.
MAX_FRAMES = 1 + 21 * 8
|
|
|
|
|
def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
    """
    Recursively rename all '.wut' files to '.pth' in the given directory.

    Only the trailing '.wut' extension is swapped, so a name such as
    'weights.wut.wut' becomes 'weights.wut.pth' instead of having every
    '.wut' substring rewritten.

    Args:
        directory (str): Path to the directory to process. A non-existent
            directory is a no-op (os.walk yields nothing for it).

    A failure to rename an individual file is reported on stdout and does not
    abort the rest of the walk.
    """
    old_ext, new_ext = '.wut', '.pth'
    directory = os.path.abspath(directory)

    for root, _, files in os.walk(directory):
        for filename in files:
            if not filename.endswith(old_ext):
                continue

            old_path = os.path.join(root, filename)
            # Replace only the suffix; str.replace would also rewrite any
            # '.wut' that appears earlier in the file name.
            new_filename = filename[:-len(old_ext)] + new_ext
            new_path = os.path.join(root, new_filename)

            try:
                os.rename(old_path, new_path)
                print(f"Renamed: {old_path} -> {new_path}")
            except OSError as e:
                print(f"Error renaming {old_path}: {e}")
|
|
|
def print_directory_structure(startpath):
    """Log the directory tree rooted at ``startpath``, one entry per line.

    Each directory is logged with a trailing '/', and its files are indented
    one level (four spaces) deeper than the directory itself. Output goes to
    the module-level ``logger`` at INFO level.
    """
    for root, _dirs, files in os.walk(startpath):
        # Depth relative to startpath. os.path.relpath strips only the leading
        # prefix (root.replace would remove every occurrence) and is not
        # thrown off by a trailing separator on startpath.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath so that a root like "/data/" still logs as "data/".
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")
|
|
|
# Import-time side effect: rename '.wut' files under /repository to '.pth'
# before any EndpointHandler is constructed. Going by the helper's name, the
# odd extension is presumably there to get the weights past an upload
# filter. NOTE(review): confirm.
# Single literal so the quotes around the path survive; the previous
# "...""/repository""..." form was implicit concatenation that dropped them.
logger.info('💡 Applying a dirty hack (patch "/repository" to fix file extensions):')
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")
|
|
|
|
|
|
|
|
|
@dataclass
class GenerationConfig:
    """Configuration for video generation.

    The field defaults are the values used when a request omits a parameter.
    Call ``validate_and_adjust()`` before use. It snaps the resolution and
    frame count to acceptable values and resolves a random seed.
    """

    # --- Text conditioning ---------------------------------------------------
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # --- Output geometry -----------------------------------------------------
    # Pixels; snapped to multiples of 32 by validate_and_adjust().
    width: int = 768
    height: int = 416
    # Must be of the form 8*k + 1 (default: 113 frames).
    num_frames: int = (8 * 14) + 1

    # --- Sampling ------------------------------------------------------------
    guidance_scale: float = 5.0
    num_inference_steps: int = 30
    # -1 requests a random seed, resolved by validate_and_adjust().
    seed: int = -1

    # --- Post-processing -----------------------------------------------------
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0

    # --- Audio ---------------------------------------------------------------
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    @staticmethod
    def _snap_side(value: float, round_down: bool = False) -> int:
        """Snap one side length to a multiple of 32 within [128, MAX_LARGE_SIDE]."""
        steps = int(value // 32) if round_down else round(value / 32)
        return max(128, min(MAX_LARGE_SIDE, steps * 32))

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters in place to meet constraints.

        Resolution:
            The exact MAX_LARGE_SIDE x MAX_SMALL_SIDE pair (either
            orientation) is kept as-is. Any other size is first scaled down
            uniformly if it exceeds MAX_SMALL_SIDE * MAX_LARGE_SIDE pixels.
            Each side is then rounded to the nearest multiple of 32 and
            clamped to [128, MAX_LARGE_SIDE]. If that rounding pushed the
            area back over the budget, both sides are rounded down instead.
        Frames:
            Rounded down to the form 8*k + 1 and clamped to [1, MAX_FRAMES].
        Seed:
            -1 is replaced by a random value in [0, 2**32 - 1].

        Returns:
            ``self``, so the call can be chained after construction.
        """
        max_total_pixels = MAX_SMALL_SIDE * MAX_LARGE_SIDE

        is_exact_max = (
            (self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE)
            or (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)
        )
        if is_exact_max:
            # Already valid; just normalise e.g. 1280.0 to 1280.
            self.width, self.height = int(self.width), int(self.height)
        else:
            width, height = self.width, self.height
            total_pixels = width * height
            if total_pixels > max_total_pixels:
                scale = (max_total_pixels / total_pixels) ** 0.5
                width, height = width * scale, height * scale

            snapped_w = self._snap_side(width)
            snapped_h = self._snap_side(height)
            if snapped_w * snapped_h > max_total_pixels:
                # Rounding both sides up can overshoot the budget
                # (e.g. 990x990 -> 992x992); flooring stays within it.
                snapped_w = self._snap_side(width, round_down=True)
                snapped_h = self._snap_side(height, round_down=True)
            self.width, self.height = snapped_w, snapped_h

        # Round down to 8*k + 1. k is clamped at 0 so inputs below 1 (e.g. 0)
        # give a single frame instead of a negative frame count.
        k = max(0, (int(self.num_frames) - 1) // 8)
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
|
|
|
class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing"""

    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish.

        Both LTX pipelines are loaded from the same ``model_path`` in bfloat16
        and moved to CUDA unconditionally, so a GPU is required here.

        Args:
            model_path: Path to LTX model weights
        """
        self.text_to_video = LTXPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Post-processor driven by the double_num_frames / super_resolution /
        # grain_amount / fps request options.
        # NOTE(review): MMAudio is disabled here, so enable_audio requests
        # presumably cannot produce audio. Confirm against Varnish.
        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            model_base_dir="/repository/varnish",
            enable_mmaudio=False,
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish and encode the video.

        Args:
            frames: Generated video frames, i.e. the ``.frames`` returned by
                an LTX pipeline called with ``output_type="pt"``.
            config: Generation configuration. Its post-processing fields are
                forwarded to Varnish, and its seed is echoed in the metadata.

        Returns:
            Tuple of (video data URI, metadata dictionary)

        Raises:
            RuntimeError: if post-processing or encoding fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            result = await self.varnish(
                input_data=frames,
                fps=config.fps,
                double_num_frames=config.double_num_frames,
                super_resolution=config.super_resolution,
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt,
            )

            # quality=17 is passed straight to Varnish's writer; its scale is
            # defined by Varnish.
            video_uri = await result.write(
                type="data-uri",
                quality=17
            )

            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
            }

            return video_uri, metadata

        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}") from e

    @staticmethod
    def _decode_image(encoded: str) -> "Image.Image":
        """Decode a base64 image, optionally wrapped in a ``data:`` URI, to RGB."""
        if encoded.startswith('data:'):
            # Drop the "data:<mime>;base64," header and keep only the payload.
            encoded = encoded.split(',', 1)[1]
        image_bytes = base64.b64decode(encoded)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation.

        Args:
            data: Request data containing:
                - inputs (dict or str): a dict with a "prompt" (text) and/or an
                  "image" (base64 string, optionally a ``data:`` URI) entry, or
                  a bare string used as the prompt. At least one of prompt or
                  image is required.
                - parameters (dict, optional):
                    - negative_prompt (optional, string): list of concepts to ignore in the video.
                    - width (optional, int, default to 768): width, or horizontal size in pixels (snapped to a multiple of 32).
                    - height (optional, int, default to 416): height, or vertical size in pixels (snapped to a multiple of 32).
                    - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame (other values are rounded down; at most 169).
                    - guidance_scale (optional, float, default to 5.0): Guidance scale
                    - num_inference_steps (optional, int, default to 30): number of inference steps
                    - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
                    - fps (optional, int, default to 30): FPS of the final video
                    - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
                    - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
                    - grain_amount (optional, float): amount of film grain to add to the output video
                    - enable_audio (optional, bool): automatically generate an audio track
                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata

        Raises:
            ValueError: if neither a prompt nor an image is provided.
            RuntimeError: if generation or post-processing fails. The message
                includes the original error and its traceback.
        """
        # `or {}` also covers explicit nulls in the request body.
        inputs = data.get("inputs") or {}
        if isinstance(inputs, str):
            # Accept the common {"inputs": "<prompt>"} request shape.
            inputs = {"prompt": inputs}

        input_prompt = inputs.get("prompt", "")
        input_image = inputs.get("image")

        params = data.get("parameters") or {}

        if not input_image and not input_prompt:
            raise ValueError("Either prompt or image must be provided")

        if input_prompt:
            logger.info(f"Prompt: {input_prompt}")

        logger.info("Raw parameters:")
        pprint.pprint(params)

        config = GenerationConfig(
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
            seed=params.get("seed", GenerationConfig.seed),
            fps=params.get("fps", GenerationConfig.fps),
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
        ).validate_and_adjust()

        logger.info("Global request settings:")
        pprint.pprint(config)

        try:
            with torch.no_grad():
                # Seed every RNG the pipelines might draw from, so that a given
                # seed is intended to reproduce the same output.
                random.seed(config.seed)
                np.random.seed(config.seed)
                generator = torch.manual_seed(config.seed)

                generation_kwargs = {
                    "prompt": config.prompt,
                    "negative_prompt": config.negative_prompt,
                    "width": config.width,
                    "height": config.height,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,
                    "output_type": "pt",
                    "generator": generator
                }

                if input_image:
                    generation_kwargs["image"] = self._decode_image(input_image)
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

                # __call__ is synchronous, so drive the async post-processing
                # on this thread's event loop. Create a fresh one if none
                # exists or the previous one has been closed.
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = None
                if loop is None or loop.is_closed():
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))

                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            print(message)
            raise RuntimeError(message) from e