jbilcke-hf (HF staff) committed
Commit 1a6f91c
1 parent: 5ed7cbf

Update handler.py

Files changed (1)
  1. handler.py (+86 -31)
handler.py CHANGED
@@ -4,6 +4,10 @@ from diffusers import LTXPipeline, LTXImageToVideoPipeline
 from PIL import Image
 import base64
 import io
+import tempfile
+import numpy as np
+from moviepy.editor import ImageSequenceClip
+import os
 
 class EndpointHandler:
     def __init__(self, path: str = ""):
@@ -26,6 +30,50 @@ class EndpointHandler:
         # Enable memory optimizations
         self.text_to_video.enable_model_cpu_offload()
         self.image_to_video.enable_model_cpu_offload()
+
+        # Set default FPS
+        self.fps = 24
+
+    def _create_video_file(self, images: torch.Tensor, fps: int = 24) -> bytes:
+        """Convert frames to an MP4 video file.
+
+        Args:
+            images (torch.Tensor): Generated frames tensor
+            fps (int): Frames per second for the output video
+
+        Returns:
+            bytes: MP4 video file content
+        """
+        # Convert tensor to numpy array
+        video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
+        video_np = (video_np * 255).astype(np.uint8)
+
+        # Get dimensions
+        height, width = video_np.shape[1:3]
+
+        # Create temporary file
+        output_path = tempfile.mktemp(suffix=".mp4")
+
+        try:
+            # Create video clip and write to file
+            clip = ImageSequenceClip(list(video_np), fps=fps)
+            resized = clip.resize((width, height))
+            resized.write_videofile(output_path, codec="libx264", audio=False)
+
+            # Read the video file
+            with open(output_path, "rb") as f:
+                video_content = f.read()
+
+            return video_content
+
+        finally:
+            # Cleanup
+            if os.path.exists(output_path):
+                os.remove(output_path)
+
+            # Clear memory
+            del video_np
+            torch.cuda.empty_cache()
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Process the input data and generate video using LTX.
@@ -35,12 +83,14 @@ class EndpointHandler:
             - prompt (str): Text description for video generation
            - image (Optional[str]): Base64 encoded image for image-to-video generation
             - num_frames (Optional[int]): Number of frames to generate (default: 24)
+            - fps (Optional[int]): Frames per second (default: 24)
             - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
             - num_inference_steps (Optional[int]): Number of inference steps (default: 50)
 
         Returns:
             Dict[str, Any]: Dictionary containing:
-            - frames: List of base64 encoded frames
+            - video: Base64 encoded MP4 video
+            - content-type: MIME type of the video
         """
         # Extract parameters
         prompt = data.get("prompt")
@@ -49,6 +99,7 @@
 
         # Get optional parameters with defaults
         num_frames = data.get("num_frames", 24)
+        fps = data.get("fps", self.fps)
         guidance_scale = data.get("guidance_scale", 7.5)
         num_inference_steps = data.get("num_inference_steps", 50)
 
@@ -56,37 +107,41 @@
         image_data = data.get("image")
 
         try:
-            if image_data:
-                # Decode base64 image
-                image_bytes = base64.b64decode(image_data)
-                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-
-                # Generate video from image
-                output = self.image_to_video(
-                    prompt=prompt,
-                    image=image,
-                    num_frames=num_frames,
-                    guidance_scale=guidance_scale,
-                    num_inference_steps=num_inference_steps
-                )
-            else:
-                # Generate video from text only
-                output = self.text_to_video(
-                    prompt=prompt,
-                    num_frames=num_frames,
-                    guidance_scale=guidance_scale,
-                    num_inference_steps=num_inference_steps
-                )
+            with torch.no_grad():
+                if image_data:
+                    # Decode base64 image
+                    image_bytes = base64.b64decode(image_data)
+                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+                    # Generate video from image
+                    output = self.image_to_video(
+                        prompt=prompt,
+                        image=image,
+                        num_frames=num_frames,
+                        guidance_scale=guidance_scale,
+                        num_inference_steps=num_inference_steps,
+                        output_type="pt"
+                    ).images
+                else:
+                    # Generate video from text only
+                    output = self.text_to_video(
+                        prompt=prompt,
+                        num_frames=num_frames,
+                        guidance_scale=guidance_scale,
+                        num_inference_steps=num_inference_steps,
+                        output_type="pt"
+                    ).images
 
-            # Convert frames to base64
-            frames = []
-            for frame in output.frames[0]:  # First element contains the frames
-                buffer = io.BytesIO()
-                frame.save(buffer, format="PNG")
-                frame_base64 = base64.b64encode(buffer.getvalue()).decode()
-                frames.append(frame_base64)
+            # Convert frames to video file
+            video_content = self._create_video_file(output, fps=fps)
+
+            # Encode video to base64
+            video_base64 = base64.b64encode(video_content).decode('utf-8')
 
-            return {"frames": frames}
+            return {
+                "video": video_base64,
+                "content-type": "video/mp4"
+            }
 
         except Exception as e:
-            raise RuntimeError(f"Error generating video: {str(e)}")
+            raise RuntimeError(f"Error generating video: {str(e)}")
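
After this commit the handler returns a single base64-encoded MP4 (a "video" field plus a "content-type" field) instead of a list of PNG frames. A minimal client-side sketch of the new request/response shape follows; the endpoint URL and token are placeholders, and the flat payload keys simply mirror what __call__ reads from data:

import base64
import requests

# Placeholder values: substitute your own Inference Endpoint URL and HF token.
API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer <hf_token>", "Content-Type": "application/json"}

payload = {
    "prompt": "A ship sailing through a storm at night",
    "num_frames": 24,
    "fps": 24,
    "guidance_scale": 7.5,
    "num_inference_steps": 50,
    # Optionally add "image": "<base64-encoded input image>" to take the
    # image-to-video branch of the handler instead of text-to-video.
}

response = requests.post(API_URL, headers=HEADERS, json=payload)
response.raise_for_status()

# The handler returns {"video": <base64 MP4>, "content-type": "video/mp4"}.
result = response.json()
with open("output.mp4", "wb") as f:
    f.write(base64.b64decode(result["video"]))

This sketch assumes the deployment hands the raw JSON body to __call__; if your serving stack wraps parameters under an "inputs" key, adjust the payload accordingly.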