jbilcke-hf HF staff committed on
Commit
f68983c
1 Parent(s): 827505c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -14
handler.py CHANGED
@@ -22,9 +22,9 @@ logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
  # Constraints
25
- MAX_WIDTH = 1280
26
- MAX_HEIGHT = 720
27
- MAX_FRAMES = 257
28
 
29
  # this is only a temporary solution (famous last words)
30
  def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
@@ -78,12 +78,13 @@ class GenerationConfig:
78
  negative_prompt: str = "saturated, overlit, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
79
 
80
  # video model settings (will be used during generation of the initial raw video clip)
81
- width: int = 768
82
- height: int = 512
83
 
84
  # users may tend to always set this to the max, to get as much useable content as possible (which is MAX_FRAMES ie. 257).
85
  # The value must be a multiple of 8, plus 1 frame.
86
- num_frames: int = 129
 
87
 
88
  guidance_scale: float = 7.5
89
  num_inference_steps: int = 50
@@ -92,7 +93,7 @@ class GenerationConfig:
92
  seed: int = -1 # -1 means random seed
93
 
94
  # varnish settings (will be used for post-processing after the raw video clip has been generated
95
- fps: int = 24 # FPS of the final video (only applied at the the very end, when converting to mp4)
96
  double_num_frames: bool = True # if True, the number of frames will be multiplied by 2 using RIFE
97
  super_resolution: bool = True # if True, the resolution will be multiplied by 2 using Real_ESRGAN
98
 
@@ -105,19 +106,31 @@ class GenerationConfig:
105
 
106
  def validate_and_adjust(self) -> 'GenerationConfig':
107
  """Validate and adjust parameters to meet constraints"""
108
- # Round dimensions to nearest multiple of 32
109
- self.width = max(32, min(MAX_WIDTH, round(self.width / 32) * 32))
110
- self.height = max(32, min(MAX_HEIGHT, round(self.height / 32) * 32))
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # Adjust number of frames to be in format 8k + 1
113
  k = (self.num_frames - 1) // 8
114
- num_frames = min((k * 8) + 1, MAX_FRAMES)
115
-
116
-
117
  # Set random seed if not specified
118
  if self.seed == -1:
119
  self.seed = random.randint(0, 2**32 - 1)
120
-
121
  return self
122
 
123
  class EndpointHandler:
 
22
  logger = logging.getLogger(__name__)
23
 
24
  # Constraints
25
# Upper bound, in pixels, for the longer side of a generated frame.
MAX_LARGE_SIDE = 1280
# Upper bound, in pixels, for the shorter side of a generated frame.
MAX_SMALL_SIDE = 720
# Frame counts follow the 8k + 1 format; visual glitches appear after about
# 169 frames, so we cap it there (k = 21).
MAX_FRAMES = (8 * 21) + 1
28
 
29
  # this is only a temporary solution (famous last words)
30
  def apply_dirty_hack_to_patch_file_extensions_and_bypass_filter(directory):
 
78
  negative_prompt: str = "saturated, overlit, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
79
 
80
  # video model settings (will be used during generation of the initial raw video clip)
81
+ width: int = 768 # max is 1280 but we use a lower value
82
+ height: int = 416 # max is 720 but we use a lower value
83
 
84
  # users may tend to always set this to the max, to get as much useable content as possible (which is MAX_FRAMES, i.e. 169).
85
  # The value must be a multiple of 8, plus 1 frame.
86
+ # visual glitches appear after about 169 frames, so we don't need more actually
87
+ num_frames: int = (8 * 14) + 1
88
 
89
  guidance_scale: float = 7.5
90
  num_inference_steps: int = 50
 
93
  seed: int = -1 # -1 means random seed
94
 
95
  # varnish settings (will be used for post-processing after the raw video clip has been generated)
96
+ fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
97
  double_num_frames: bool = True # if True, the number of frames will be multiplied by 2 using RIFE
98
  super_resolution: bool = True # if True, the resolution will be multiplied by 2 using Real_ESRGAN
99
 
 
106
 
107
  def validate_and_adjust(self) -> 'GenerationConfig':
108
  """Validate and adjust parameters to meet constraints"""
109
+ # First check if it's one of our explicitly allowed resolutions
110
+ if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
111
+ (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
112
+ # For other resolutions, ensure total pixels don't exceed max
113
+ MAX_TOTAL_PIXELS = (MAX_SMALL_SIDE * MAX_LARGE_SIDE- # or 921600 = 1280 * 720
114
+
115
+ # If total pixels exceed maximum, scale down proportionally
116
+ total_pixels = self.width * self.height
117
+ if total_pixels > MAX_TOTAL_PIXELS:
118
+ scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
119
+ self.width = max(32, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
120
+ self.height = max(32, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
121
+ else:
122
+ # Round dimensions to nearest multiple of 32
123
+ self.width = max(32, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
124
+ self.height = max(32, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
125
 
126
  # Adjust number of frames to be in format 8k + 1
127
  k = (self.num_frames - 1) // 8
128
+ self.num_frames = min((k * 8) + 1, MAX_FRAMES)
129
+
 
130
  # Set random seed if not specified
131
  if self.seed == -1:
132
  self.seed = random.randint(0, 2**32 - 1)
133
+
134
  return self
135
 
136
  class EndpointHandler: