Commit dfc94bc · Parent: ac45732
upgrade finetrainers + freeze datasets
finetrainers/data/dataset.py
CHANGED
@@ -970,59 +970,9 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
     return image
 
-
-
-
-
-
-
-        video = video.get_batch(list(range(len(video))))
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # For torchvision VideoReader
-    elif 'torchvision.io.video_reader' in str(type(video)):
-        # Use the correct iteration pattern for torchvision.io.VideoReader
-        frames = []
-        try:
-            # First seek to the beginning
-            video.seek(0)
-
-            # Then collect frames by iterating
-            for _ in range(30):  # Try to get a reasonable number of frames
-                try:
-                    frame_dict = next(video)
-                    frame = frame_dict["data"]  # Extract the tensor data from the dict
-                    frames.append(frame)
-                except StopIteration:
-                    break
-        except Exception as e:
-            print(f"Error iterating VideoReader: {e}")
-
-        if frames:
-            # In torchvision.io.VideoReader, frames are already in [C, H, W] format
-            # We need to stack and convert to [B, C, H, W]
-            stacked_frames = torch.stack(frames)
-            # Normalize to [-1, 1]
-            stacked_frames = stacked_frames.float() / 127.5 - 1.0
-            return stacked_frames
-
-        # If we couldn't get frames, create a dummy tensor
-        print("Failed to get frames, creating dummy tensor")
-        return torch.zeros(16, 3, 512, 768).float()
-
-    # For list of PIL images
-    elif isinstance(video, list) and len(video) > 0 and hasattr(video[0], 'convert'):
-        frames = []
-        for img in video:
-            img_tensor = torch.from_numpy(np.array(img.convert("RGB"))).float()
-            frames.append(img_tensor)
-
-        video = torch.stack(frames)
-        video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
-        return video
-
-    # Unknown type
-    else:
-        print(f"Unknown video type: {type(video)}")
-        return torch.zeros(16, 3, 512, 768).float()
+
+def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+    video = video.get_batch(list(range(len(video))))
+    video = video.permute(0, 3, 1, 2).contiguous()
+    video = video.float() / 127.5 - 1.0
+    return video
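The rewritten _preprocess_video assumes decord is the only decoder in play (which the datasets==3.3.2 pin below guarantees). A minimal sketch of how it would be exercised; the sample path and the explicit torch bridge call are assumptions for illustration, not part of this commit:

import decord
import torch

# get_batch returns a decord NDArray unless the torch bridge is set;
# the diff's .permute() call implies torch tensors, so set it here.
decord.bridge.set_bridge("torch")

reader = decord.VideoReader("sample.mp4")  # hypothetical clip
frames = reader.get_batch(list(range(len(reader))))  # [T, H, W, C], uint8
frames = frames.permute(0, 3, 1, 2).contiguous()     # -> [T, C, H, W]
frames = frames.float() / 127.5 - 1.0                # normalize to [-1, 1]
assert -1.0 <= frames.min() and frames.max() <= 1.0
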
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -694,13 +694,14 @@ class SFTTrainer:
         # 3. Cleanup & log artifacts
         parallel_backend.wait_for_everyone()
 
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+
         # Remove all hooks that might have been added during pipeline initialization to the models
+        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
         pipeline.remove_all_hooks()
         del pipeline
-
-        utils.free_memory()
-        memory_statistics = utils.get_memory_statistics()
-        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        self._delete_components(module_names)
         torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
         # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
@@ -788,7 +789,7 @@
 
     def _init_trackers(self) -> None:
         # TODO(aryan): handle multiple trackers
-        trackers = [
+        trackers = [self.args.report_to]
         experiment_name = self.args.tracker_name or "finetrainers-experiment"
         self.state.parallel_backend.initialize_trackers(
             trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
@@ -836,7 +837,6 @@
         utils.synchronize_device()
 
     def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
-        parallel_backend = self.state.parallel_backend
         module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]
 
         if not final_validation:
@@ -871,7 +871,6 @@
             enable_tiling=self.args.enable_tiling,
             enable_model_cpu_offload=self.args.enable_model_cpu_offload,
             training=False,
-            device=parallel_backend.device,
         )
 
         # Load the LoRA weights if performing LoRA finetuning
@@ -880,7 +879,8 @@
 
         components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
         self._set_components(components)
-        self.
+        if not self.args.enable_model_cpu_offload:
+            self._move_components_to_device(list(components.values()))
         return pipeline
 
     def _prepare_data(
@@ -923,17 +923,12 @@
         else:
             logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
 
-
-
-
-
-
-
-            # force=True,
-            # _device=parallel_backend.device,
-            # _is_main_process=parallel_backend.is_main_process,
-            # )
-            # self._delete_components(component_names=["transformer", "unet"])
+            parallel_backend = self.state.parallel_backend
+            if parallel_backend.world_size == 1:
+                self._move_components_to_device([self.transformer], "cpu")
+                utils.free_memory()
+                utils.synchronize_device()
+                torch.cuda.reset_peak_memory_stats(parallel_backend.device)
 
             if self.args.precomputation_once:
                 consume_fn = preprocessor.consume_once
@@ -974,8 +969,8 @@
             self._delete_components(component_names)
             del latent_components, component_names, component_modules
 
-
-
+            if parallel_backend.world_size == 1:
+                self._move_components_to_device([self.transformer])
 
             return condition_iterator, latent_iterator
 
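The world_size == 1 branches add a simple CPU-offload dance: park the trainable transformer on CPU while the text encoders and VAE run precomputation, then move it back before training resumes. A standalone sketch of the same pattern, assuming a generic nn.Module and a caller-supplied precompute_fn (neither is a finetrainers API):

import gc
import torch
import torch.nn as nn

def precompute_with_offload(transformer: nn.Module, precompute_fn, device: str = "cuda"):
    # Park the large trainable module on CPU so precomputation
    # (prompt embeddings, VAE latents) has the whole GPU to itself.
    transformer.to("cpu")
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)

    result = precompute_fn()  # runs text encoder / VAE forward passes

    # Restore the transformer for the training loop.
    transformer.to(device)
    return result

The diff gates this on parallel_backend.world_size == 1, presumably to avoid shuffling sharded or distributed modules across devices in multi-process runs.
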
requirements.txt
CHANGED
@@ -7,6 +7,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord with torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
requirements_without_flash_attention.txt
CHANGED
@@ -8,6 +8,10 @@ torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
+# datasets 3.4.0 replaces decord with torchvision
+# let's freeze it for now
+datasets==3.3.2
+
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
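Both requirements files get the same pin, since datasets >= 3.4.0 decodes video via torchvision instead of decord, which would break the decord-only _preprocess_video above. A quick post-install check that the pin took effect (a throwaway snippet, not part of the commit):

import importlib.metadata

# Fail fast if the environment resolved a newer datasets release.
version = importlib.metadata.version("datasets")
assert version == "3.3.2", f"expected datasets==3.3.2, got {version}"
print(f"datasets {version} OK")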