smoothieAI commited on
Commit
ba50290
·
1 Parent(s): 0eee060

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +8 -7
pipeline.py CHANGED
@@ -773,19 +773,20 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
773
 
774
  # save frames
775
  if output_path is not None:
776
- output_batch_size = 10 #this prevents out of memory errors with large videos
777
- num_frames = latents.size(2) #latents' shape is [batch, channels, frames, height, width]
 
 
 
778
  for start_idx in range(0, num_frames, output_batch_size):
779
  end_idx = min(start_idx + output_batch_size, num_frames)
780
  video_tensor = self.decode_latents(latents[:, :, start_idx:end_idx, :, :])
781
  video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
782
 
783
  for batch_idx, frame_batch in enumerate(video):
784
- for frame in frame_batch[0][0]:
785
- digit_substring = ''.join(filter(str.isdigit, output_path))
786
- frame_number = int(digit_substring) + start_idx + batch_idx
787
- new_output_path = output_path.replace(digit_substring, str(frame_number).zfill(5), 1)
788
- frame.save(new_output_path)
789
  return output_path
790
 
791
  # Post-processing
 
773
 
774
  # save frames
775
  if output_path is not None:
776
+ output_batch_size = 10 # prevents out of memory errors with large videos
777
+ num_frames = latents.size(2) # latents' shape is [batch, channels, frames, height, width]
778
+ num_digits = output_path.count('#') # count the number of '#' characters
779
+ frame_format = output_path.replace('#' * num_digits, '{:0' + str(num_digits) + 'd}')
780
+
781
  for start_idx in range(0, num_frames, output_batch_size):
782
  end_idx = min(start_idx + output_batch_size, num_frames)
783
  video_tensor = self.decode_latents(latents[:, :, start_idx:end_idx, :, :])
784
  video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
785
 
786
  for batch_idx, frame_batch in enumerate(video):
787
+ for frame_idx, frame in enumerate(frame_batch[0][0]):
788
+ frame_number = start_idx + batch_idx * len(frame_batch[0][0]) + frame_idx
789
+ frame.save(frame_format.format(frame_number))
 
 
790
  return output_path
791
 
792
  # Post-processing