jbilcke-hf HF staff committed on
Commit
d585ae1
1 Parent(s): 0e79ca6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -30
handler.py CHANGED
@@ -98,6 +98,11 @@ class GenerationConfig:
98
 
99
  grain_amount: float = 0.0
100
 
 
 
 
 
 
101
  def validate_and_adjust(self) -> 'GenerationConfig':
102
  """Validate and adjust parameters to meet constraints"""
103
  # Round dimensions to nearest multiple of 32
@@ -148,7 +153,6 @@ class EndpointHandler:
148
  output_format="mp4",
149
  output_codec="h264",
150
  output_quality=17,
151
- enable_mmaudio=False,
152
  model_base_dir="/repository/varnish",
153
  )
154
 
@@ -167,14 +171,6 @@ class EndpointHandler:
167
  Tuple of (video data URI, metadata dictionary)
168
  """
169
  try:
170
- logger.info(f"Original frames shape: {frames.shape}")
171
-
172
- # Remove batch dimension if present
173
- if len(frames.shape) == 5:
174
- frames = frames.squeeze(0) # Remove batch dimension
175
-
176
- logger.info(f"Processed frames shape: {frames.shape}")
177
-
178
  # Process video with Varnish
179
  result = await self.varnish(
180
  input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
@@ -182,6 +178,9 @@ class EndpointHandler:
182
  double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
183
  super_resolution=config.grain_amount_config, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
184
  grain_amount_config.grain_amount,
 
 
 
185
  )
186
 
187
  # Convert to data URI
@@ -228,6 +227,9 @@ class EndpointHandler:
228
  - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
229
  - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
230
  - grain_amount (optional, float): amount of film grain to add to the output video
 
 
 
231
  Returns:
232
  Dictionary containing:
233
  - video: Base64 encoded MP4 data URI
@@ -270,6 +272,9 @@ class EndpointHandler:
270
  double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
271
  super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
272
  grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
 
 
 
273
  ).validate_and_adjust()
274
 
275
  logger.info(f"Global request settings:")
@@ -316,33 +321,15 @@ class EndpointHandler:
316
  frames = self.image_to_video(**generation_kwargs).frames
317
  else:
318
  frames = self.text_to_video(**generation_kwargs).frames
319
-
320
-
321
- # Log original shape
322
- logger.info(f"Original frames shape: {frames.shape}")
323
-
324
- # Remove batch dimension if present
325
- if len(frames.shape) == 5:
326
- frames = frames.squeeze(0) # Remove batch dimension
327
-
328
- logger.info(f"Processed frames shape: {frames.shape}")
329
-
330
- # Ensure we have the correct shape
331
- if len(frames.shape) != 4:
332
- raise ValueError(f"Expected tensor of shape [frames, channels, height, width], got shape {frames.shape}")
333
-
334
- # Post-process frames
335
 
336
  try:
337
  loop = asyncio.get_event_loop()
338
  except RuntimeError:
339
  loop = asyncio.new_event_loop()
340
  asyncio.set_event_loop(loop)
341
-
342
- video_uri, metadata = loop.run_until_complete(
343
- self.process_frames(frames, config)
344
- )
345
-
346
  return {
347
  "video": video_uri,
348
  "content-type": "video/mp4",
 
98
 
99
  grain_amount: float = 0.0
100
 
101
+ # audio settings
102
+ enable_audio: bool = False # Whether to generate audio
103
+ audio_prompt: str = "" # Text prompt for audio generation
104
+ audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation
105
+
106
  def validate_and_adjust(self) -> 'GenerationConfig':
107
  """Validate and adjust parameters to meet constraints"""
108
  # Round dimensions to nearest multiple of 32
 
153
  output_format="mp4",
154
  output_codec="h264",
155
  output_quality=17,
 
156
  model_base_dir="/repository/varnish",
157
  )
158
 
 
171
  Tuple of (video data URI, metadata dictionary)
172
  """
173
  try:
 
 
 
 
 
 
 
 
174
  # Process video with Varnish
175
  result = await self.varnish(
176
  input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
 
178
  double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
179
  super_resolution=config.grain_amount_config, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
180
  grain_amount_config.grain_amount,
181
+ enable_audio=config.enable_audio,
182
+ audio_prompt=config.audio_prompt,
183
+ audio_negative_prompt=config.audio_negative_prompt,
184
  )
185
 
186
  # Convert to data URI
 
227
  - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
228
  - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
229
  - grain_amount (optional, float): amount of film grain to add to the output video
230
+ - enable_audio (optional, bool): automatically generate an audio track
231
+ - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
232
+ - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
233
  Returns:
234
  Dictionary containing:
235
  - video: Base64 encoded MP4 data URI
 
272
  double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
273
  super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
274
  grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
275
+ enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
276
+ audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
277
+ audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
278
  ).validate_and_adjust()
279
 
280
  logger.info(f"Global request settings:")
 
321
  frames = self.image_to_video(**generation_kwargs).frames
322
  else:
323
  frames = self.text_to_video(**generation_kwargs).frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  try:
326
  loop = asyncio.get_event_loop()
327
  except RuntimeError:
328
  loop = asyncio.new_event_loop()
329
  asyncio.set_event_loop(loop)
330
+
331
+ video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
332
+
 
 
333
  return {
334
  "video": video_uri,
335
  "content-type": "video/mp4",