Commit
•
d585ae1
1
Parent(s):
0e79ca6
Update handler.py
Browse files- handler.py +17 -30
handler.py
CHANGED
@@ -98,6 +98,11 @@ class GenerationConfig:
|
|
98 |
|
99 |
grain_amount: float = 0.0
|
100 |
|
|
|
|
|
|
|
|
|
|
|
101 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
102 |
"""Validate and adjust parameters to meet constraints"""
|
103 |
# Round dimensions to nearest multiple of 32
|
@@ -148,7 +153,6 @@ class EndpointHandler:
|
|
148 |
output_format="mp4",
|
149 |
output_codec="h264",
|
150 |
output_quality=17,
|
151 |
-
enable_mmaudio=False,
|
152 |
model_base_dir="/repository/varnish",
|
153 |
)
|
154 |
|
@@ -167,14 +171,6 @@ class EndpointHandler:
|
|
167 |
Tuple of (video data URI, metadata dictionary)
|
168 |
"""
|
169 |
try:
|
170 |
-
logger.info(f"Original frames shape: {frames.shape}")
|
171 |
-
|
172 |
-
# Remove batch dimension if present
|
173 |
-
if len(frames.shape) == 5:
|
174 |
-
frames = frames.squeeze(0) # Remove batch dimension
|
175 |
-
|
176 |
-
logger.info(f"Processed frames shape: {frames.shape}")
|
177 |
-
|
178 |
# Process video with Varnish
|
179 |
result = await self.varnish(
|
180 |
input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
|
@@ -182,6 +178,9 @@ class EndpointHandler:
|
|
182 |
double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
|
183 |
super_resolution=config.super_resolution, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
184 |
grain_amount=config.grain_amount,
|
|
|
|
|
|
|
185 |
)
|
186 |
|
187 |
# Convert to data URI
|
@@ -228,6 +227,9 @@ class EndpointHandler:
|
|
228 |
- double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
|
229 |
- super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
|
230 |
- grain_amount (optional, float): amount of film grain to add to the output video
|
|
|
|
|
|
|
231 |
Returns:
|
232 |
Dictionary containing:
|
233 |
- video: Base64 encoded MP4 data URI
|
@@ -270,6 +272,9 @@ class EndpointHandler:
|
|
270 |
double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
|
271 |
super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
272 |
grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
|
|
|
|
|
|
|
273 |
).validate_and_adjust()
|
274 |
|
275 |
logger.info(f"Global request settings:")
|
@@ -316,33 +321,15 @@ class EndpointHandler:
|
|
316 |
frames = self.image_to_video(**generation_kwargs).frames
|
317 |
else:
|
318 |
frames = self.text_to_video(**generation_kwargs).frames
|
319 |
-
|
320 |
-
|
321 |
-
# Log original shape
|
322 |
-
logger.info(f"Original frames shape: {frames.shape}")
|
323 |
-
|
324 |
-
# Remove batch dimension if present
|
325 |
-
if len(frames.shape) == 5:
|
326 |
-
frames = frames.squeeze(0) # Remove batch dimension
|
327 |
-
|
328 |
-
logger.info(f"Processed frames shape: {frames.shape}")
|
329 |
-
|
330 |
-
# Ensure we have the correct shape
|
331 |
-
if len(frames.shape) != 4:
|
332 |
-
raise ValueError(f"Expected tensor of shape [frames, channels, height, width], got shape {frames.shape}")
|
333 |
-
|
334 |
-
# Post-process frames
|
335 |
|
336 |
try:
|
337 |
loop = asyncio.get_event_loop()
|
338 |
except RuntimeError:
|
339 |
loop = asyncio.new_event_loop()
|
340 |
asyncio.set_event_loop(loop)
|
341 |
-
|
342 |
-
video_uri, metadata = loop.run_until_complete(
|
343 |
-
|
344 |
-
)
|
345 |
-
|
346 |
return {
|
347 |
"video": video_uri,
|
348 |
"content-type": "video/mp4",
|
|
|
98 |
|
99 |
grain_amount: float = 0.0
|
100 |
|
101 |
+
# audio settings
|
102 |
+
enable_audio: bool = False # Whether to generate audio
|
103 |
+
audio_prompt: str = "" # Text prompt for audio generation
|
104 |
+
audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation
|
105 |
+
|
106 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
107 |
"""Validate and adjust parameters to meet constraints"""
|
108 |
# Round dimensions to nearest multiple of 32
|
|
|
153 |
output_format="mp4",
|
154 |
output_codec="h264",
|
155 |
output_quality=17,
|
|
|
156 |
model_base_dir="/repository/varnish",
|
157 |
)
|
158 |
|
|
|
171 |
Tuple of (video data URI, metadata dictionary)
|
172 |
"""
|
173 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
# Process video with Varnish
|
175 |
result = await self.varnish(
|
176 |
input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
|
|
|
178 |
double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
|
179 |
super_resolution=config.super_resolution, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
180 |
grain_amount=config.grain_amount,
|
181 |
+
enable_audio=config.enable_audio,
|
182 |
+
audio_prompt=config.audio_prompt,
|
183 |
+
audio_negative_prompt=config.audio_negative_prompt,
|
184 |
)
|
185 |
|
186 |
# Convert to data URI
|
|
|
227 |
- double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
|
228 |
- super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
|
229 |
- grain_amount (optional, float): amount of film grain to add to the output video
|
230 |
+
- enable_audio (optional, bool): automatically generate an audio track
|
231 |
+
- audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
|
232 |
+
- audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
|
233 |
Returns:
|
234 |
Dictionary containing:
|
235 |
- video: Base64 encoded MP4 data URI
|
|
|
272 |
double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
|
273 |
super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
|
274 |
grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
|
275 |
+
enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
|
276 |
+
audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
|
277 |
+
audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
|
278 |
).validate_and_adjust()
|
279 |
|
280 |
logger.info(f"Global request settings:")
|
|
|
321 |
frames = self.image_to_video(**generation_kwargs).frames
|
322 |
else:
|
323 |
frames = self.text_to_video(**generation_kwargs).frames
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
|
325 |
try:
|
326 |
loop = asyncio.get_event_loop()
|
327 |
except RuntimeError:
|
328 |
loop = asyncio.new_event_loop()
|
329 |
asyncio.set_event_loop(loop)
|
330 |
+
|
331 |
+
video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
|
332 |
+
|
|
|
|
|
333 |
return {
|
334 |
"video": video_uri,
|
335 |
"content-type": "video/mp4",
|