jbilcke-hf (HF Staff) committed
Commit d78dede · Parent: 64a70c0

investigate bugs in Finetrainers

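Work around recurring `decord._ffi.base.DECORDError: Resource temporarily unavailable` failures during video loading: raise the file descriptor limit (RLIMIT_NOFILE) both at dataset import time and at the start of training, retry decord reads with increasing backoff, force periodic garbage collection in the datasets and the training loop, and rename `resolutions` to `resolution_buckets` in the CogVideoX dataset so `_find_nearest_resolution` reads the attribute that is actually defined.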
finetrainers/dataset.py CHANGED
@@ -15,6 +15,9 @@ from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
 
+import gc
+import time
+import resource
 
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
@@ -30,6 +33,22 @@ from .constants import (  # noqa
 )
 
 logger = get_logger(__name__)
 
+# Decord is causing us some issues!
+# Let's try to increase file descriptor limits to avoid this error:
+#
+#   decord._ffi.base.DECORDError: Resource temporarily unavailable
+# (This block must run after `logger = get_logger(__name__)`, which it uses.)
+try:
+    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+    logger.info(f"Current file descriptor limits: soft={soft}, hard={hard}")
+
+    # Raise the soft limit to the hard limit if possible
+    if soft < hard:
+        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+        new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
+except Exception as e:
+    logger.warning(f"Could not check or update file descriptor limits: {e}")
+
 
@@ -229,20 +248,46 @@ class ImageOrVideoDataset(Dataset):
         return image
 
     def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        r"""
+        """
         Loads a single video, or latent and prompt embedding, based on initialization parameters.
-
         Returns a [F, C, H, W] video tensor.
         """
-        video_reader = decord.VideoReader(uri=path.as_posix())
-        video_num_frames = len(video_reader)
-
-        indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
-        frames = video_reader.get_batch(indices)
-        frames = frames[: self.max_num_frames].float()
-        frames = frames.permute(0, 3, 1, 2).contiguous()
-        frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
-        return frames
+        max_retries = 3
+        retry_delay = 1.0  # seconds
+
+        for attempt in range(max_retries):
+            try:
+                # Create video reader
+                video_reader = decord.VideoReader(uri=path.as_posix())
+                video_num_frames = len(video_reader)
+
+                # Process frames; guard the step so it is never zero for short videos
+                indices = list(range(0, video_num_frames, max(1, video_num_frames // self.max_num_frames)))
+                frames = video_reader.get_batch(indices)
+                frames = frames[: self.max_num_frames].float()
+                frames = frames.permute(0, 3, 1, 2).contiguous()
+                frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
+
+                # Explicitly clean up resources
+                del video_reader
+
+                # Force garbage collection occasionally
+                if random.random() < 0.05:  # 5% chance
+                    gc.collect()
+
+                return frames
+
+            except decord._ffi.base.DECORDError as e:
+                error_msg = str(e)
+                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
+                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
+                    # Clean up and wait before retrying, with increasing backoff
+                    gc.collect()
+                    time.sleep(retry_delay * (attempt + 1))
+                else:
+                    # Either not a resource error or we've run out of retries
+                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
+                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") from e
 
 
 class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
@@ -264,35 +309,58 @@ class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
         return image
 
     def _preprocess_video(self, path: Path) -> torch.Tensor:
-        video_reader = decord.VideoReader(uri=path.as_posix())
-        video_num_frames = len(video_reader)
-        #print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
-        #print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
-        #print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
-
-        video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
-
-        if not video_buckets:
-            _, h, w = self.resolution_buckets[0]
-            video_buckets = [(1, h, w)]
-
-        nearest_frame_bucket = min(
-            video_buckets,
-            key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
-            default=video_buckets[0],
-        )[0]
-
-        frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
-
-        frames = video_reader.get_batch(frame_indices)
-        frames = frames[:nearest_frame_bucket].float()
-        frames = frames.permute(0, 3, 1, 2).contiguous()
-
-        nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
-        frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
-        frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
-
-        return frames
+        max_retries = 3
+        retry_delay = 1.0  # seconds
+
+        for attempt in range(max_retries):
+            try:
+                # Create video reader
+                video_reader = decord.VideoReader(uri=path.as_posix())
+                video_num_frames = len(video_reader)
+
+                # Find the nearest frame bucket the video can fill
+                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
+
+                if not video_buckets:
+                    _, h, w = self.resolution_buckets[0]
+                    video_buckets = [(1, h, w)]
+
+                nearest_frame_bucket = min(
+                    video_buckets,
+                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
+                    default=video_buckets[0],
+                )[0]
+
+                # Extract and process frames
+                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
+                frames = video_reader.get_batch(frame_indices)
+                frames = frames[:nearest_frame_bucket].float()
+                frames = frames.permute(0, 3, 1, 2).contiguous()
+
+                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
+                frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
+                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
+
+                # Explicitly clean up resources
+                del video_reader
+
+                # Force garbage collection occasionally
+                if random.random() < 0.05:  # 5% chance
+                    gc.collect()
+
+                return frames
+
+            except decord._ffi.base.DECORDError as e:
+                error_msg = str(e)
+                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
+                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
+                    # Clean up and wait before retrying, with increasing backoff
+                    gc.collect()
+                    time.sleep(retry_delay * (attempt + 1))
+                else:
+                    # Either not a resource error or we've run out of retries
+                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
+                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") from e
 
     def _find_nearest_resolution(self, height, width):
         nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
@@ -338,35 +406,59 @@ class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
         return arr
 
     def _preprocess_video(self, path: Path) -> torch.Tensor:
-        video_reader = decord.VideoReader(uri=path.as_posix())
-        video_num_frames = len(video_reader)
-        print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.resolution_buckets = ", self.resolution_buckets)
-        print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.max_num_frames = ", self.max_num_frames)
-        print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: video_num_frames = ", video_num_frames)
-
-        video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
-
-        if not video_buckets:
-            _, h, w = self.resolution_buckets[0]
-            video_buckets = [(1, h, w)]
-
-        nearest_frame_bucket = min(
-            video_buckets,
-            key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
-            default=video_buckets[0],
-        )[0]
-
-        frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
-
-        frames = video_reader.get_batch(frame_indices)
-        frames = frames[:nearest_frame_bucket].float()
-        frames = frames.permute(0, 3, 1, 2).contiguous()
-
-        nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
-        frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
-        frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
-        return frames
-
+        max_retries = 3
+        retry_delay = 1.0  # seconds
+
+        for attempt in range(max_retries):
+            try:
+                # Create video reader
+                video_reader = decord.VideoReader(uri=path.as_posix())
+                video_num_frames = len(video_reader)
+
+                # Find the nearest frame bucket the video can fill
+                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
+
+                if not video_buckets:
+                    _, h, w = self.resolution_buckets[0]
+                    video_buckets = [(1, h, w)]
+
+                nearest_frame_bucket = min(
+                    video_buckets,
+                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
+                    default=video_buckets[0],
+                )[0]
+
+                # Extract and process frames
+                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
+                frames = video_reader.get_batch(frame_indices)
+                frames = frames[:nearest_frame_bucket].float()
+                frames = frames.permute(0, 3, 1, 2).contiguous()
+
+                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
+                frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
+                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
+
+                # Explicitly clean up resources
+                del video_reader
+
+                # Force garbage collection occasionally
+                if random.random() < 0.05:  # 5% chance
+                    gc.collect()
+
+                return frames
+
+            except decord._ffi.base.DECORDError as e:
+                error_msg = str(e)
+                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
+                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
+                    # Clean up and wait before retrying, with increasing backoff
+                    gc.collect()
+                    time.sleep(retry_delay * (attempt + 1))
+                else:
+                    # Either not a resource error or we've run out of retries
+                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
+                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") from e
 
     def _find_nearest_resolution(self, height, width):
-        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
+        # Fix: use `resolution_buckets`, matching the attribute the base class defines
+        nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
         return nearest_res[1], nearest_res[2]
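The same retry-with-backoff block now appears verbatim in all three `_preprocess_video` implementations. A follow-up could factor it into a decorator; below is a minimal sketch, not part of this commit (the name `retry_on_decord_error` and its parameters are illustrative):

    import functools
    import gc
    import time

    import decord  # must still be imported after torch, per the comment in dataset.py


    def retry_on_decord_error(max_retries=3, retry_delay=1.0):
        """Retry a callable when decord reports 'Resource temporarily unavailable'."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for attempt in range(max_retries):
                    try:
                        return fn(*args, **kwargs)
                    except decord._ffi.base.DECORDError as e:
                        transient = "Resource temporarily unavailable" in str(e)
                        if transient and attempt < max_retries - 1:
                            # Clean up and wait with increasing backoff
                            gc.collect()
                            time.sleep(retry_delay * (attempt + 1))
                        else:
                            raise RuntimeError(f"Failed after {attempt + 1} attempts: {e}") from e
            return wrapper
        return decorator

Each `_preprocess_video` could then keep only its happy path and be decorated with `@retry_on_decord_error()`.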
finetrainers/trainer.py CHANGED
@@ -2,6 +2,8 @@ import json
 import logging
 import math
 import os
+import gc
+import resource  # needed for the RLIMIT_NOFILE check in Trainer.train() below
 import random
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -549,6 +551,19 @@ class Trainer:
     def train(self) -> None:
        logger.info("Starting training")
 
+        # Raise the file descriptor soft limit to the hard limit, to mitigate
+        # decord's "Resource temporarily unavailable" errors during data loading
+        if hasattr(resource, "RLIMIT_NOFILE"):
+            try:
+                soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+                logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}")
+                if soft < hard:
+                    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+                    new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+                    logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
+            except Exception as e:
+                logger.warning(f"Could not check or update file descriptor limits: {e}")
+
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -816,9 +831,13 @@ class Trainer:
                    progress_bar.set_postfix(logs)
                    accelerator.log(logs, step=global_step)
 
+                    if global_step % 100 == 0:
+                        # Force garbage collection periodically to release lingering resources
+                        gc.collect()
+
                    if global_step >= self.state.train_steps:
                        break
 
            if num_loss_updates > 0:
                epoch_loss /= num_loss_updates
                accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
@@ -833,6 +852,14 @@ class Trainer:
            if should_run_validation:
                self.validate(global_step)
 
+            if epoch % 3 == 0:
+                # Periodic resource cleanup every few epochs
+                logger.info("Performing periodic resource cleanup")
+                free_memory()
+                gc.collect()
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize(accelerator.device)
+
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            transformer = unwrap_model(accelerator, self.transformer)
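Since this commit is investigative, one way to confirm whether decord is actually leaking file descriptors (rather than merely hitting a low limit) is to sample the process's open-descriptor count during training. A minimal, Linux-only sketch (it assumes `/proc` exists; the helper name is illustrative):

    import os
    import resource


    def log_fd_usage(tag=""):
        """Print how many file descriptors the current process holds (Linux only)."""
        soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            open_fds = len(os.listdir("/proc/self/fd"))
        except FileNotFoundError:
            return  # /proc is not available (e.g. macOS)
        print(f"[fd-usage{tag}] {open_fds} open / soft limit {soft}")

A steadily climbing count under a fixed workload would point at unreleased `VideoReader` handles; a flat count near the soft limit would point at the limit itself.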
training/cogvideox/dataset.py CHANGED
@@ -57,7 +57,7 @@ class VideoDataset(Dataset):
         self.random_flip = random_flip
         self.image_to_video = image_to_video
 
-        self.resolutions = [
+        self.resolution_buckets = [
            (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
         ]
 
@@ -295,7 +295,7 @@ class VideoDatasetWithResizing(VideoDataset):
         return image, frames, None
 
     def _find_nearest_resolution(self, height, width):
-        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
+        nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
         return nearest_res[1], nearest_res[2]
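For reference, `_find_nearest_resolution` picks the bucket minimizing the L1 distance between the bucket's (height, width) and the frame's. A quick worked example with hypothetical buckets:

    resolution_buckets = [(49, 480, 720), (49, 512, 768), (49, 720, 1080)]

    height, width = 500, 750
    nearest = min(resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
    print(nearest[1], nearest[2])  # 512 768, since |512 - 500| + |768 - 750| = 30 is the smallest distance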