Commit dfc94bc committed by jbilcke-hf (HF staff) · 1 parent: ac45732

upgrade finetrainers + freeze datasets

finetrainers/data/dataset.py CHANGED
@@ -970,59 +970,9 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
     return image
 
-def _preprocess_video(video) -> torch.Tensor:
-    import torch
-    import numpy as np
-
-    # For decord VideoReader
-    if hasattr(video, 'get_batch') and 'decord' in str(type(video)):
-        video = video.get_batch(list(range(len(video))))
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # For torchvision VideoReader
-    elif 'torchvision.io.video_reader' in str(type(video)):
-        # Use the correct iteration pattern for torchvision.io.VideoReader
-        frames = []
-        try:
-            # First seek to the beginning
-            video.seek(0)
-
-            # Then collect frames by iterating
-            for _ in range(30):  # Try to get a reasonable number of frames
-                try:
-                    frame_dict = next(video)
-                    frame = frame_dict["data"]  # Extract the tensor data from the dict
-                    frames.append(frame)
-                except StopIteration:
-                    break
-        except Exception as e:
-            print(f"Error iterating VideoReader: {e}")
-
-        if frames:
-            # In torchvision.io.VideoReader, frames are already in [C, H, W] format
-            # We need to stack and convert to [B, C, H, W]
-            stacked_frames = torch.stack(frames)
-            # Normalize to [-1, 1]
-            stacked_frames = stacked_frames.float() / 127.5 - 1.0
-            return stacked_frames
-
-        # If we couldn't get frames, create a dummy tensor
-        print("Failed to get frames, creating dummy tensor")
-        return torch.zeros(16, 3, 512, 768).float()
-
-    # For list of PIL images
-    elif isinstance(video, list) and len(video) > 0 and hasattr(video[0], 'convert'):
-        frames = []
-        for img in video:
-            img_tensor = torch.from_numpy(np.array(img.convert("RGB"))).float()
-            frames.append(img_tensor)
-
-        video = torch.stack(frames)
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # Unknown type
-    else:
-        print(f"Unknown video type: {type(video)}")
-        return torch.zeros(16, 3, 512, 768).float()
+
+def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+    video = video.get_batch(list(range(len(video))))
+    video = video.permute(0, 3, 1, 2).contiguous()
+    video = video.float() / 127.5 - 1.0
+    return video
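The simplified _preprocess_video above assumes the decord-backed VideoReader that datasets==3.3.2 (pinned below) still hands out. A rough sketch of the expected input and output follows; the file path, the bridge call, and the assertions are illustrative, not part of the commit:

# Sketch: exercising the decord path assumed by _preprocess_video.
# Assumes decord is installed and the torch bridge is active so get_batch returns tensors.
import decord
import torch

decord.bridge.set_bridge("torch")

reader = decord.VideoReader("sample.mp4")                   # illustrative path
frames = reader.get_batch(list(range(len(reader))))         # [T, H, W, C], uint8
frames = frames.permute(0, 3, 1, 2).contiguous()            # [T, C, H, W]
frames = frames.float() / 127.5 - 1.0                       # normalize to [-1, 1]
assert frames.dtype == torch.float32
assert frames.min() >= -1.0 and frames.max() <= 1.0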
 
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -694,13 +694,14 @@ class SFTTrainer:
         # 3. Cleanup & log artifacts
         parallel_backend.wait_for_everyone()
 
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+
         # Remove all hooks that might have been added during pipeline initialization to the models
+        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
         pipeline.remove_all_hooks()
         del pipeline
-
-        utils.free_memory()
-        memory_statistics = utils.get_memory_statistics()
-        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        self._delete_components(module_names)
         torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
         # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
@@ -788,7 +789,7 @@
 
     def _init_trackers(self) -> None:
         # TODO(aryan): handle multiple trackers
-        trackers = ["wandb"]
+        trackers = [self.args.report_to]
         experiment_name = self.args.tracker_name or "finetrainers-experiment"
         self.state.parallel_backend.initialize_trackers(
             trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
@@ -836,7 +837,6 @@
         utils.synchronize_device()
 
     def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
-        parallel_backend = self.state.parallel_backend
         module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]
 
         if not final_validation:
@@ -871,7 +871,6 @@
             enable_tiling=self.args.enable_tiling,
             enable_model_cpu_offload=self.args.enable_model_cpu_offload,
             training=False,
-            device=parallel_backend.device,
         )
 
         # Load the LoRA weights if performing LoRA finetuning
@@ -880,7 +879,8 @@
 
         components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
         self._set_components(components)
-        self._move_components_to_device(list(components.values()))
+        if not self.args.enable_model_cpu_offload:
+            self._move_components_to_device(list(components.values()))
         return pipeline
 
     def _prepare_data(
@@ -923,17 +923,12 @@
         else:
             logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
 
-            # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
-            # had become None after this but should have been loaded back from the checkpoint.
-            # parallel_backend = self.state.parallel_backend
-            # train_state = self.state.train_state
-            # self.checkpointer.save(
-            #     train_state.step,
-            #     force=True,
-            #     _device=parallel_backend.device,
-            #     _is_main_process=parallel_backend.is_main_process,
-            # )
-            # self._delete_components(component_names=["transformer", "unet"])
+            parallel_backend = self.state.parallel_backend
+            if parallel_backend.world_size == 1:
+                self._move_components_to_device([self.transformer], "cpu")
+                utils.free_memory()
+                utils.synchronize_device()
+                torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
         if self.args.precomputation_once:
             consume_fn = preprocessor.consume_once
@@ -974,8 +969,8 @@
         self._delete_components(component_names)
         del latent_components, component_names, component_modules
 
-        # self.checkpointer.load()
-        # self.transformer = self.checkpointer.states["model"].model[0]
+        if parallel_backend.world_size == 1:
+            self._move_components_to_device([self.transformer])
 
         return condition_iterator, latent_iterator
 
 
requirements.txt CHANGED
@@ -7,6 +7,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord with torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
requirements_without_flash_attention.txt CHANGED
@@ -8,6 +8,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord with torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
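
Both requirement files pin datasets==3.3.2 so that video decoding keeps going through decord rather than torchvision. A quick post-install sanity check (illustrative, not part of the commit) might look like:

# Illustrative check that the pin took effect and decord is importable.
import datasets
import decord

assert datasets.__version__ == "3.3.2", datasets.__version__
print("datasets", datasets.__version__, "| decord", decord.__version__)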