from dataclasses import dataclass
from typing import Dict, Any
import asyncio
import base64
import io
import pprint
import logging
import random
import traceback
import os
import numpy as np
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image

from varnish import Varnish

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constraints
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768 # would be 720 (for 720p), but the value must be divisible by 32
MAX_FRAMES = (8 * 21) + 1 # = 169; visual glitches appear after about 169 frames, so we cap it there

# this is only a temporary solution (famous last words)
def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
    """
    Recursively rename all '.wut' files to '.pth' in the given directory
    
    Args:
        directory (str): Path to the directory to process
    """
    # Convert the directory path to absolute path
    directory = os.path.abspath(directory)
    
    # Walk through directory and its subdirectories
    for root, _, files in os.walk(directory):
        for filename in files:
            if filename.endswith('.wut'):
                # Get full path of the file
                old_path = os.path.join(root, filename)
                # Create new filename by replacing only the trailing '.wut'
                # (str.replace would also touch any earlier occurrence in the name)
                new_filename = filename[:-len('.wut')] + '.pth'
                new_path = os.path.join(root, new_filename)
                
                try:
                    os.rename(old_path, new_path)
                    print(f"Renamed: {old_path} -> {new_path}")
                except OSError as e:
                    print(f"Error renaming {old_path}: {e}")

def print_directory_structure(startpath):
    """Print the directory structure starting from the given path."""
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * level
        logger.info(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")

logger.info("💡 Applying a dirty hack (patch ""/repository"" to fix file extensions):")
apply_dirty_hack_to_patch_file_extensions_and_bypass_filter("/repository")

#logger.info("💡 Printing directory structure of ""/repository"":")
#print_directory_structure("/repository")


def process_input_image(image_data: str, target_width: int, target_height: int) -> Image.Image:
    """
    Process input image from base64, resize and crop to target dimensions
    
    Args:
        image_data: Base64 encoded image data
        target_width: Desired width
        target_height: Desired height
        
    Returns:
        Processed PIL Image
    """
    try:
        # Handle data URI format
        if image_data.startswith('data:'):
            image_data = image_data.split(',', 1)[1]
            
        # Decode base64
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes))
        
        # Convert to RGB if necessary
        if image.mode not in ('RGB', 'RGBA'):
            image = image.convert('RGB')
        elif image.mode == 'RGBA':
            # Handle transparency by compositing on white background
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background
            
        # Calculate target aspect ratio
        target_aspect = target_width / target_height
        
        # Get current dimensions
        orig_width, orig_height = image.size
        orig_aspect = orig_width / orig_height
        
        # Calculate dimensions for resizing
        if orig_aspect > target_aspect:
            # Image is wider than target
            new_height = target_height
            new_width = int(target_height * orig_aspect)
        else:
            # Image is taller than target
            new_width = target_width
            new_height = int(target_width / orig_aspect)
            
        # Resize image
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        
        # Center crop to target dimensions
        left = (new_width - target_width) // 2
        top = (new_height - target_height) // 2
        right = left + target_width
        bottom = top + target_height
        
        image = image.crop((left, top, right, bottom))
        
        return image
        
    except Exception as e:
        raise ValueError(f"Failed to process input image: {e}") from e
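
# A quick illustration of the resize-then-center-crop behavior above (a sketch,
# not executed at import time; assumes a square input and the default 768x416
# target):
#
#   image = process_input_image(base64_png_string, 768, 416)
#   # a 1024x1024 input is taller than the ~1.85 target aspect ratio, so it is
#   # first resized to 768x768, then center-cropped to 768x416
#   # (top = (768 - 416) // 2 = 176)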

@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # general content settings
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # video model settings (will be used during generation of the initial raw video clip)
    # we use small values to make things a bit faster
    width: int = 768
    height: int = 416

    # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
    # The value must be a multiple of 8, plus 1 frame.
    # visual glitches appear after about 169 frames, so we don't need more anyway
    num_frames: int = (8 * 14) + 1

    # values between 3.0 and 4.0 are nice
    guidance_scale: float = 3.5
    
    num_inference_steps: int = 50

    # reproducible generation settings
    seed: int = -1  # -1 means random seed

    # varnish settings (will be used for post-processing after the raw video clip has been generated)
    fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
    double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
    super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN
    
    grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression

    # audio settings
    enable_audio: bool = False  # Whether to generate audio
    audio_prompt: str = ""  # Text prompt for audio generation
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # First check if it's one of our explicitly allowed resolutions
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or 
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            # For other resolutions, ensure total pixels don't exceed max
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # 983,040 pixels (768 * 1280)
            
            # If total pixels exceed maximum, scale down proportionally
            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                # Round dimensions to nearest multiple of 32
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
        
        # Adjust number of frames to be in format 8k + 1
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)
    
        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)
    
        return self
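
# Illustrative example of the adjustment logic above (a sketch; the values were
# worked out by hand from the rules in validate_and_adjust, not a tested assertion):
#
#   cfg = GenerationConfig(width=1920, height=1080, num_frames=100).validate_and_adjust()
#   # 1920x1080 exceeds the 983,040-pixel budget, so both sides are scaled down
#   # and snapped to multiples of 32 -> 1280x736; num_frames is snapped down to
#   # the 8k + 1 grid -> 97; a random seed is drawn because seed == -1.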

class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing"""
    
    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish

        Args:
            model_path: Path to LTX model weights
        """
        # Enable TF32 for potential speedup on Ampere GPUs
        #torch.backends.cuda.matmul.allow_tf32 = True
        
        # Initialize models with bfloat16 precision
        self.text_to_video = LTXPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Enable CPU offload for memory efficiency
        #self.text_to_video.enable_model_cpu_offload()
        #self.image_to_video.enable_model_cpu_offload()

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda" if torch.cuda.is_available() else "cpu",
            model_base_dir="/repository/varnish",

            # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
            # not sure how to fix that.. :/
            #
            # it says:
            #   File "dist-packages/varnish.py", line 152, in __init__
            #     self._setup_mmaudio()
            #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
            #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
            #                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            #   File "dist-packages/torch/serialization.py", line 1384, in load
            #     return _legacy_load(
            #            ^^^^^^^^^^^^^
            #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
            #     magic_number = pickle_module.load(f, **pickle_load_args)
            #                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            # _pickle.UnpicklingError: invalid load key, '<'.
            enable_mmaudio=False,
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish
        
        Args:
            frames: Generated video frames tensor
            config: Generation configuration
            
        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Process video with Varnish
            result = await self.varnish(
                input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
                fps=config.fps, # this is the FPS of the final output video. Varnish can use it to compute the clip duration (frames * factor / fps, etc.)
                double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
                super_resolution=config.super_resolution, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt, 
            )
            
            # Convert to data URI
            video_uri = await result.write(
                type="data-uri",
                quality=17
            )
            
            # Collect metadata
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
            }
            
            return video_uri, metadata
    
        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}")


    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation
        
        Args:
            data: Request data containing:
                - inputs (dict): Dictionary containing input, which can be either "prompt" (text field) or "image" (input image)
                - parameters (dict):
                    - prompt (required, string): list of concepts to keep in the video.
                    - negative_prompt (optional, string): list of concepts to ignore in the video.
                    - width (optional, int, default to 768): width, or horizontal size in pixels.
                    - height (optional, int, default to 416): height, or vertical size in pixels.
                    - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame.
                    - guidance_scale (optional, float, default to 3.5): guidance scale (values between 3.0 and 4.0 are nice)
                    - num_inference_steps (optional, int, default to 50): number of inference steps
                    - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
                    - fps (optional, int, default to 30): FPS of the final video (eg. 24, 25, 30, 60)
                    - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
                    - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
                    - grain_amount (optional, float): amount of film grain to add to the output video
                    - enable_audio (optional, bool): automatically generate an audio track
                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        inputs = data.get("inputs", dict())
        
        input_prompt = inputs.get("prompt", "")
        input_image = inputs.get("image")
        
        params = data.get("parameters", dict())

        if not input_image and not input_prompt:
            raise ValueError("Either prompt or image must be provided")
      
        if input_prompt:
            logger.info(f"Prompt: {input_prompt}")
                   
        logger.info(f"Raw parameters:")
        pprint.pprint(params)

        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),

            # video model settings (will be used during generation of the initial raw video clip)
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),
            
            # varnish settings (will be used for post-processing after the raw video clip has been generated)
            fps=params.get("fps", GenerationConfig.fps), # FPS of the final video (only applied at the very end, when converting to mp4)
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
        ).validate_and_adjust()
        
        logger.info(f"Global request settings:")
        pprint.pprint(config)

        try:
            with torch.no_grad():
                # Set random seeds
                random.seed(config.seed)
                np.random.seed(config.seed)
                generator = torch.manual_seed(config.seed)
                
                # Prepare generation parameters for the video model (we omit params destined for Varnish, and things like the seed, which is set externally)
                generation_kwargs = {
                    # general content settings
                    "prompt": config.prompt,
                    "negative_prompt": config.negative_prompt,
        
                    # video model settings (will be used during generation of the initial raw video clip)
                    "width": config.width,
                    "height": config.height,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,
 
                    # constants
                    "output_type": "pt",
                    "generator": generator
                }
                #logger.info(f"Video model generation settings:")
                #pprint.pprint(generation_kwargs)
                
                # Check if image-to-video generation is requested
                if input_image:
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height
                    )
                    generation_kwargs["image"] = processed_image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                
                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
                
                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            print(message)
            raise RuntimeError(message)
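
# Illustrative local usage (a sketch; "/repository" is the model path used
# elsewhere in this file, and the payload fields mirror the __call__ docstring):
#
#   handler = EndpointHandler("/repository")
#   result = handler({
#       "inputs": {"prompt": "a sunrise over the ocean, gentle waves"},
#       "parameters": {"width": 768, "height": 416, "num_frames": 97, "fps": 30},
#   })
#   # result["video"]    -> "data:video/mp4;base64,..."
#   # result["metadata"] -> width, height, num_frames, fps, duration, seed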