jbilcke-hf (HF Staff) committed
Commit 66c6879 · 1 Parent(s): c8589f9

fix for Finetrainers

finetrainers/dataset.py CHANGED
@@ -32,25 +32,23 @@ from .constants import ( # noqa
     PRECOMPUTED_LATENTS_DIR_NAME,
 )
 
-logger = get_logger(__name__)
-
 # Decord is causing us some issues!
 # Let's try to increase file descriptor limits to avoid this error:
 #
 # decord._ffi.base.DECORDError: Resource temporarily unavailable
 try:
     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-    logger.info(f"Current file descriptor limits: soft={soft}, hard={hard}")
+    print(f"Current file descriptor limits: soft={soft}, hard={hard}")
 
     # Try to increase to hard limit if possible
     if soft < hard:
         resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
         new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-        logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
+        print(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
 except Exception as e:
-    logger.warning(f"Could not check or update file descriptor limits: {e}")
-
+    print(f"Could not check or update file descriptor limits: {e}")
 
 
+logger = get_logger(__name__)
 
 # TODO(aryan): This needs a refactor with separation of concerns.
 # Images should be handled separately. Videos should be handled separately.
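The hunk above creates the logger only after the limit-raising block and switches the log calls to plain print, presumably so the block no longer depends on logging being configured at import time. The same block is duplicated inside the trainer's train() method (see the file added below); a small reusable helper could factor it out. The sketch below is hypothetical and not part of this commit; the function name and the optional logger argument are assumptions.

import resource


def raise_file_descriptor_limit(logger=None):
    """Best-effort bump of the soft RLIMIT_NOFILE limit up to the hard limit.

    Helps avoid decord's "Resource temporarily unavailable" errors when many
    video files are open at once. Returns the resulting soft limit, or None
    if the limit could not be inspected or changed.
    """
    log = logger.info if logger is not None else print
    try:
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        log(f"Current file descriptor limits: soft={soft}, hard={hard}")
        if soft < hard:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
            log(f"Updated file descriptor limits: soft={soft}, hard={hard}")
        return soft
    except Exception as e:
        log(f"Could not check or update file descriptor limits: {e}")
        return None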
finetrainers/finetrainers__lib__trainer.py ADDED
@@ -0,0 +1,1235 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import os
5
+ import gc
6
+ import random
7
+ from datetime import datetime, timedelta
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List
10
+
11
+ import diffusers
12
+ import torch
13
+ import torch.backends
14
+ import transformers
15
+ import wandb
16
+ from accelerate import Accelerator, DistributedType
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import (
19
+ DistributedDataParallelKwargs,
20
+ InitProcessGroupKwargs,
21
+ ProjectConfiguration,
22
+ gather_object,
23
+ set_seed,
24
+ )
25
+ from diffusers import DiffusionPipeline
26
+ from diffusers.configuration_utils import FrozenDict
27
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
28
+ from diffusers.optimization import get_scheduler
29
+ from diffusers.training_utils import cast_training_params
30
+ from diffusers.utils import export_to_video, load_image, load_video
31
+ from huggingface_hub import create_repo, upload_folder
32
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
33
+ from tqdm import tqdm
34
+
35
+ from .args import Args, validate_args
36
+ from .constants import (
37
+ FINETRAINERS_LOG_LEVEL,
38
+ PRECOMPUTED_CONDITIONS_DIR_NAME,
39
+ PRECOMPUTED_DIR_NAME,
40
+ PRECOMPUTED_LATENTS_DIR_NAME,
41
+ )
42
+ from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
43
+ from .hooks import apply_layerwise_upcasting
44
+ from .models import get_config_from_model_name
45
+ from .patches import perform_peft_patches
46
+ from .state import State
47
+ from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
48
+ from .utils.data_utils import should_perform_precomputation
49
+ from .utils.diffusion_utils import (
50
+ get_scheduler_alphas,
51
+ get_scheduler_sigmas,
52
+ prepare_loss_weights,
53
+ prepare_sigmas,
54
+ prepare_target,
55
+ )
56
+ from .utils.file_utils import string_to_filename
57
+ from .utils.hub_utils import save_model_card
58
+ from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
59
+ from .utils.model_utils import resolve_vae_cls_from_ckpt_path
60
+ from .utils.optimizer_utils import get_optimizer
61
+ from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
62
+
63
+
64
+ logger = get_logger("finetrainers")
65
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
66
+
67
+
68
+ class Trainer:
69
+ def __init__(self, args: Args) -> None:
70
+ validate_args(args)
71
+
72
+ self.args = args
73
+ self.args.seed = self.args.seed or datetime.now().year
74
+ self.state = State()
75
+
76
+ # Tokenizers
77
+ self.tokenizer = None
78
+ self.tokenizer_2 = None
79
+ self.tokenizer_3 = None
80
+
81
+ # Text encoders
82
+ self.text_encoder = None
83
+ self.text_encoder_2 = None
84
+ self.text_encoder_3 = None
85
+
86
+ # Denoisers
87
+ self.transformer = None
88
+ self.unet = None
89
+
90
+ # Autoencoders
91
+ self.vae = None
92
+
93
+ # Scheduler
94
+ self.scheduler = None
95
+
96
+ self.transformer_config = None
97
+ self.vae_config = None
98
+
99
+ self._init_distributed()
100
+ self._init_logging()
101
+ self._init_directories_and_repositories()
102
+ self._init_config_options()
103
+
104
+ # Perform any patches needed for training
105
+ if len(self.args.layerwise_upcasting_modules) > 0:
106
+ perform_peft_patches()
107
+ # TODO(aryan): handle text encoders
108
+ # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
109
+ # perform_text_encoder_patches()
110
+
111
+ self.state.model_name = self.args.model_name
112
+ self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
113
+
114
+ def prepare_dataset(self) -> None:
115
+ # TODO(aryan): Make a background process for fetching
116
+ logger.info("Initializing dataset and dataloader")
117
+
118
+ self.dataset = ImageOrVideoDatasetWithResizing(
119
+ data_root=self.args.data_root,
120
+ caption_column=self.args.caption_column,
121
+ video_column=self.args.video_column,
122
+ resolution_buckets=self.args.video_resolution_buckets,
123
+ dataset_file=self.args.dataset_file,
124
+ id_token=self.args.id_token,
125
+ remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
126
+ )
127
+ self.dataloader = torch.utils.data.DataLoader(
128
+ self.dataset,
129
+ batch_size=1,
130
+ sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
131
+ collate_fn=self.model_config.get("collate_fn"),
132
+ num_workers=self.args.dataloader_num_workers,
133
+ pin_memory=self.args.pin_memory,
134
+ )
135
+
136
+ def prepare_models(self) -> None:
137
+ logger.info("Initializing models")
138
+
139
+ load_components_kwargs = self._get_load_components_kwargs()
140
+ condition_components, latent_components, diffusion_components = {}, {}, {}
141
+ if not self.args.precompute_conditions:
142
+ # To download the model files first on the main process (if not already present)
143
+ # and then load the cached files afterward from the other processes.
144
+ with self.state.accelerator.main_process_first():
145
+ condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
146
+ latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
147
+ diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
148
+
149
+ components = {}
150
+ components.update(condition_components)
151
+ components.update(latent_components)
152
+ components.update(diffusion_components)
153
+ self._set_components(components)
154
+
155
+ if self.vae is not None:
156
+ if self.args.enable_slicing:
157
+ self.vae.enable_slicing()
158
+ if self.args.enable_tiling:
159
+ self.vae.enable_tiling()
160
+
161
+ def prepare_precomputations(self) -> None:
162
+ if not self.args.precompute_conditions:
163
+ return
164
+
165
+ logger.info("Initializing precomputations")
166
+
167
+ if self.args.batch_size != 1:
168
+ raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.")
169
+
170
+ def collate_fn(batch):
171
+ latent_conditions = [x["latent_conditions"] for x in batch]
172
+ text_conditions = [x["text_conditions"] for x in batch]
173
+ batched_latent_conditions = {}
174
+ batched_text_conditions = {}
175
+ for key in list(latent_conditions[0].keys()):
176
+ if torch.is_tensor(latent_conditions[0][key]):
177
+ batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
178
+ else:
179
+ # TODO(aryan): implement batch sampler for precomputed latents
180
+ batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
181
+ for key in list(text_conditions[0].keys()):
182
+ if torch.is_tensor(text_conditions[0][key]):
183
+ batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
184
+ else:
185
+ # TODO(aryan): implement batch sampler for precomputed latents
186
+ batched_text_conditions[key] = [x[key] for x in text_conditions][0]
187
+ return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
188
+
189
+ cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
190
+ precomputation_dir = (
191
+ Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
192
+ )
193
+ should_precompute = should_perform_precomputation(precomputation_dir)
194
+ if not should_precompute:
195
+ logger.info("Precomputed conditions and latents found. Loading precomputed data.")
196
+ self.dataloader = torch.utils.data.DataLoader(
197
+ PrecomputedDataset(
198
+ data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
199
+ ),
200
+ batch_size=self.args.batch_size,
201
+ shuffle=True,
202
+ collate_fn=collate_fn,
203
+ num_workers=self.args.dataloader_num_workers,
204
+ pin_memory=self.args.pin_memory,
205
+ )
206
+ return
207
+
208
+ logger.info("Precomputed conditions and latents not found. Running precomputation.")
209
+
210
+ # At this point, no models are loaded, so we need to load and precompute conditions and latents
211
+ with self.state.accelerator.main_process_first():
212
+ condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
213
+ self._set_components(condition_components)
214
+ self._move_components_to_device()
215
+ self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
216
+
217
+ if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
218
+ logger.warning(
219
+ "Caption dropout is not supported with precomputation yet. This will be supported in the future."
220
+ )
221
+
222
+ conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
223
+ latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
224
+ conditions_dir.mkdir(parents=True, exist_ok=True)
225
+ latents_dir.mkdir(parents=True, exist_ok=True)
226
+
227
+ accelerator = self.state.accelerator
228
+
229
+ # Precompute conditions
230
+ progress_bar = tqdm(
231
+ range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
232
+ desc="Precomputing conditions",
233
+ disable=not accelerator.is_local_main_process,
234
+ )
235
+ index = 0
236
+ for i, data in enumerate(self.dataset):
237
+ if i % accelerator.num_processes != accelerator.process_index:
238
+ continue
239
+
240
+ logger.debug(
241
+ f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
242
+ )
243
+
244
+ text_conditions = self.model_config["prepare_conditions"](
245
+ tokenizer=self.tokenizer,
246
+ tokenizer_2=self.tokenizer_2,
247
+ tokenizer_3=self.tokenizer_3,
248
+ text_encoder=self.text_encoder,
249
+ text_encoder_2=self.text_encoder_2,
250
+ text_encoder_3=self.text_encoder_3,
251
+ prompt=data["prompt"],
252
+ device=accelerator.device,
253
+ dtype=self.args.transformer_dtype,
254
+ )
255
+ filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
256
+ torch.save(text_conditions, filename.as_posix())
257
+ index += 1
258
+ progress_bar.update(1)
259
+ self._delete_components()
260
+
261
+ memory_statistics = get_memory_statistics()
262
+ logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
263
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
264
+
265
+ # Precompute latents
266
+ with self.state.accelerator.main_process_first():
267
+ latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
268
+ self._set_components(latent_components)
269
+ self._move_components_to_device()
270
+ self._disable_grad_for_components([self.vae])
271
+
272
+ if self.vae is not None:
273
+ if self.args.enable_slicing:
274
+ self.vae.enable_slicing()
275
+ if self.args.enable_tiling:
276
+ self.vae.enable_tiling()
277
+
278
+ progress_bar = tqdm(
279
+ range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
280
+ desc="Precomputing latents",
281
+ disable=not accelerator.is_local_main_process,
282
+ )
283
+ index = 0
284
+ for i, data in enumerate(self.dataset):
285
+ if i % accelerator.num_processes != accelerator.process_index:
286
+ continue
287
+
288
+ logger.debug(
289
+ f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
290
+ )
291
+
292
+ latent_conditions = self.model_config["prepare_latents"](
293
+ vae=self.vae,
294
+ image_or_video=data["video"].unsqueeze(0),
295
+ device=accelerator.device,
296
+ dtype=self.args.transformer_dtype,
297
+ generator=self.state.generator,
298
+ precompute=True,
299
+ )
300
+ filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
301
+ torch.save(latent_conditions, filename.as_posix())
302
+ index += 1
303
+ progress_bar.update(1)
304
+ self._delete_components()
305
+
306
+ accelerator.wait_for_everyone()
307
+ logger.info("Precomputation complete")
308
+
309
+ memory_statistics = get_memory_statistics()
310
+ logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
311
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
312
+
313
+ # Update dataloader to use precomputed conditions and latents
314
+ self.dataloader = torch.utils.data.DataLoader(
315
+ PrecomputedDataset(
316
+ data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
317
+ ),
318
+ batch_size=self.args.batch_size,
319
+ shuffle=True,
320
+ collate_fn=collate_fn,
321
+ num_workers=self.args.dataloader_num_workers,
322
+ pin_memory=self.args.pin_memory,
323
+ )
324
+
325
+ def prepare_trainable_parameters(self) -> None:
326
+ logger.info("Initializing trainable parameters")
327
+
328
+ with self.state.accelerator.main_process_first():
329
+ diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
330
+ self._set_components(diffusion_components)
331
+
332
+ components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
333
+ self._disable_grad_for_components(components)
334
+
335
+ if self.args.training_type == "full-finetune":
336
+ logger.info("Finetuning transformer with no additional parameters")
337
+ self._enable_grad_for_components([self.transformer])
338
+ else:
339
+ logger.info("Finetuning transformer with PEFT parameters")
340
+ self._disable_grad_for_components([self.transformer])
341
+
342
+ # Layerwise upcasting must be applied before adding the LoRA adapter.
343
+ # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
344
+ # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
345
+ if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
346
+ apply_layerwise_upcasting(
347
+ self.transformer,
348
+ storage_dtype=self.args.layerwise_upcasting_storage_dtype,
349
+ compute_dtype=self.args.transformer_dtype,
350
+ skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
351
+ non_blocking=True,
352
+ )
353
+
354
+ self._move_components_to_device()
355
+
356
+ if self.args.gradient_checkpointing:
357
+ self.transformer.enable_gradient_checkpointing()
358
+
359
+ if self.args.training_type == "lora":
360
+ transformer_lora_config = LoraConfig(
361
+ r=self.args.rank,
362
+ lora_alpha=self.args.lora_alpha,
363
+ init_lora_weights=True,
364
+ target_modules=self.args.target_modules,
365
+ )
366
+ self.transformer.add_adapter(transformer_lora_config)
367
+ else:
368
+ transformer_lora_config = None
369
+
370
+ # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
371
+ # even if layerwise upcasting. Would be nice to have a test as well
372
+
373
+ self.register_saving_loading_hooks(transformer_lora_config)
374
+
375
+ def register_saving_loading_hooks(self, transformer_lora_config):
376
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
377
+ def save_model_hook(models, weights, output_dir):
378
+ if self.state.accelerator.is_main_process:
379
+ transformer_lora_layers_to_save = None
380
+
381
+ for model in models:
382
+ if isinstance(
383
+ unwrap_model(self.state.accelerator, model),
384
+ type(unwrap_model(self.state.accelerator, self.transformer)),
385
+ ):
386
+ model = unwrap_model(self.state.accelerator, model)
387
+ if self.args.training_type == "lora":
388
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
389
+ else:
390
+ raise ValueError(f"Unexpected save model: {model.__class__}")
391
+
392
+ # make sure to pop weight so that corresponding model is not saved again
393
+ if weights:
394
+ weights.pop()
395
+
396
+ if self.args.training_type == "lora":
397
+ self.model_config["pipeline_cls"].save_lora_weights(
398
+ output_dir,
399
+ transformer_lora_layers=transformer_lora_layers_to_save,
400
+ )
401
+ else:
402
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
403
+
404
+ # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
405
+ # to be able to load all diffusion components from a specific checkpoint folder during validation, we need to
406
+ # ensure the scheduler config is serialized as well.
407
+ self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
408
+
409
+ def load_model_hook(models, input_dir):
410
+ if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
411
+ while len(models) > 0:
412
+ model = models.pop()
413
+ if isinstance(
414
+ unwrap_model(self.state.accelerator, model),
415
+ type(unwrap_model(self.state.accelerator, self.transformer)),
416
+ ):
417
+ transformer_ = unwrap_model(self.state.accelerator, model)
418
+ else:
419
+ raise ValueError(
420
+ f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
421
+ )
422
+ else:
423
+ transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
424
+
425
+ if self.args.training_type == "lora":
426
+ transformer_ = transformer_cls_.from_pretrained(
427
+ self.args.pretrained_model_name_or_path, subfolder="transformer"
428
+ )
429
+ transformer_.add_adapter(transformer_lora_config)
430
+ lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
431
+ transformer_state_dict = {
432
+ f'{k.replace("transformer.", "")}': v
433
+ for k, v in lora_state_dict.items()
434
+ if k.startswith("transformer.")
435
+ }
436
+ incompatible_keys = set_peft_model_state_dict(
437
+ transformer_, transformer_state_dict, adapter_name="default"
438
+ )
439
+ if incompatible_keys is not None:
440
+ # check only for unexpected keys
441
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
442
+ if unexpected_keys:
443
+ logger.warning(
444
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
445
+ f" {unexpected_keys}. "
446
+ )
447
+ else:
448
+ transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
449
+
450
+ self.state.accelerator.register_save_state_pre_hook(save_model_hook)
451
+ self.state.accelerator.register_load_state_pre_hook(load_model_hook)
452
+
453
+ def prepare_optimizer(self) -> None:
454
+ logger.info("Initializing optimizer and lr scheduler")
455
+
456
+ self.state.train_epochs = self.args.train_epochs
457
+ self.state.train_steps = self.args.train_steps
458
+
459
+ # Make sure the trainable params are in float32
460
+ if self.args.training_type == "lora":
461
+ cast_training_params([self.transformer], dtype=torch.float32)
462
+
463
+ self.state.learning_rate = self.args.lr
464
+ if self.args.scale_lr:
465
+ self.state.learning_rate = (
466
+ self.state.learning_rate
467
+ * self.args.gradient_accumulation_steps
468
+ * self.args.batch_size
469
+ * self.state.accelerator.num_processes
470
+ )
471
+
472
+ transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
473
+ transformer_parameters_with_lr = {
474
+ "params": transformer_trainable_parameters,
475
+ "lr": self.state.learning_rate,
476
+ }
477
+ params_to_optimize = [transformer_parameters_with_lr]
478
+ self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)
479
+
480
+ use_deepspeed_opt = (
481
+ self.state.accelerator.state.deepspeed_plugin is not None
482
+ and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
483
+ )
484
+ optimizer = get_optimizer(
485
+ params_to_optimize=params_to_optimize,
486
+ optimizer_name=self.args.optimizer,
487
+ learning_rate=self.state.learning_rate,
488
+ beta1=self.args.beta1,
489
+ beta2=self.args.beta2,
490
+ beta3=self.args.beta3,
491
+ epsilon=self.args.epsilon,
492
+ weight_decay=self.args.weight_decay,
493
+ use_8bit=self.args.use_8bit_bnb,
494
+ use_deepspeed=use_deepspeed_opt,
495
+ )
496
+
497
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
498
+ if self.state.train_steps is None:
499
+ self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
500
+ self.state.overwrote_max_train_steps = True
501
+
502
+ use_deepspeed_lr_scheduler = (
503
+ self.state.accelerator.state.deepspeed_plugin is not None
504
+ and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
505
+ )
506
+ total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
507
+ num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes
508
+
509
+ if use_deepspeed_lr_scheduler:
510
+ from accelerate.utils import DummyScheduler
511
+
512
+ lr_scheduler = DummyScheduler(
513
+ name=self.args.lr_scheduler,
514
+ optimizer=optimizer,
515
+ total_num_steps=total_training_steps,
516
+ num_warmup_steps=num_warmup_steps,
517
+ )
518
+ else:
519
+ lr_scheduler = get_scheduler(
520
+ name=self.args.lr_scheduler,
521
+ optimizer=optimizer,
522
+ num_warmup_steps=num_warmup_steps,
523
+ num_training_steps=total_training_steps,
524
+ num_cycles=self.args.lr_num_cycles,
525
+ power=self.args.lr_power,
526
+ )
527
+
528
+ self.optimizer = optimizer
529
+ self.lr_scheduler = lr_scheduler
530
+
531
+ def prepare_for_training(self) -> None:
532
+ self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
533
+ self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
534
+ )
535
+
536
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
537
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
538
+ if self.state.overwrote_max_train_steps:
539
+ self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
540
+ # Afterwards we recalculate our number of training epochs
541
+ self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
542
+ self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
543
+
544
+ def prepare_trackers(self) -> None:
545
+ logger.info("Initializing trackers")
546
+
547
+ tracker_name = self.args.tracker_name or "finetrainers-experiment"
548
+ self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())
549
+
550
+ def train(self) -> None:
551
+ logger.info("Starting training")
552
+
553
+
554
+ # Add these lines at the beginning
555
+ if hasattr(resource, 'RLIMIT_NOFILE'):
556
+ try:
557
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
558
+ logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}")
559
+ # Try to increase to hard limit if possible
560
+ if soft < hard:
561
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
562
+ new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
563
+ logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
564
+ except Exception as e:
565
+ logger.warning(f"Could not check or update file descriptor limits: {e}")
566
+
567
+ memory_statistics = get_memory_statistics()
568
+ logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
569
+
570
+ if self.vae_config is None:
571
+ # If we've precomputed conditions and latents already, and are now re-using it, we will never load
572
+ # the VAE so self.vae_config will not be set. So, we need to load it here.
573
+ vae_cls = resolve_vae_cls_from_ckpt_path(
574
+ self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
575
+ )
576
+ vae_config = vae_cls.load_config(
577
+ self.args.pretrained_model_name_or_path,
578
+ subfolder="vae",
579
+ revision=self.args.revision,
580
+ cache_dir=self.args.cache_dir,
581
+ )
582
+ self.vae_config = FrozenDict(**vae_config)
583
+
584
+ # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
585
+ # to be able to load all diffusion components from a specific checkpoint folder during validation, we need to
586
+ # ensure the scheduler config is serialized as well.
587
+ if self.args.training_type == "full-finetune":
588
+ self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
589
+
590
+ self.state.train_batch_size = (
591
+ self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
592
+ )
593
+ info = {
594
+ "trainable parameters": self.state.num_trainable_parameters,
595
+ "total samples": len(self.dataset),
596
+ "train epochs": self.state.train_epochs,
597
+ "train steps": self.state.train_steps,
598
+ "batches per device": self.args.batch_size,
599
+ "total batches observed per epoch": len(self.dataloader),
600
+ "train batch size": self.state.train_batch_size,
601
+ "gradient accumulation steps": self.args.gradient_accumulation_steps,
602
+ }
603
+ logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
604
+
605
+ global_step = 0
606
+ first_epoch = 0
607
+ initial_global_step = 0
608
+
609
+ # Potentially load in the weights and states from a previous save
610
+ (
611
+ resume_from_checkpoint_path,
612
+ initial_global_step,
613
+ global_step,
614
+ first_epoch,
615
+ ) = get_latest_ckpt_path_to_resume_from(
616
+ resume_from_checkpoint=self.args.resume_from_checkpoint,
617
+ num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
618
+ output_dir=self.args.output_dir,
619
+ )
620
+ if resume_from_checkpoint_path:
621
+ self.state.accelerator.load_state(resume_from_checkpoint_path)
622
+
623
+ progress_bar = tqdm(
624
+ range(0, self.state.train_steps),
625
+ initial=initial_global_step,
626
+ desc="Training steps",
627
+ disable=not self.state.accelerator.is_local_main_process,
628
+ )
629
+
630
+ accelerator = self.state.accelerator
631
+ generator = torch.Generator(device=accelerator.device)
632
+ if self.args.seed is not None:
633
+ generator = generator.manual_seed(self.args.seed)
634
+ self.state.generator = generator
635
+
636
+ scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
637
+ scheduler_sigmas = (
638
+ scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
639
+ if scheduler_sigmas is not None
640
+ else None
641
+ )
642
+ scheduler_alphas = get_scheduler_alphas(self.scheduler)
643
+ scheduler_alphas = (
644
+ scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
645
+ if scheduler_alphas is not None
646
+ else None
647
+ )
648
+
649
+ for epoch in range(first_epoch, self.state.train_epochs):
650
+ logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
651
+
652
+ self.transformer.train()
653
+ models_to_accumulate = [self.transformer]
654
+ epoch_loss = 0.0
655
+ num_loss_updates = 0
656
+
657
+ for step, batch in enumerate(self.dataloader):
658
+ logger.debug(f"Starting step {step + 1}")
659
+ logs = {}
660
+
661
+ with accelerator.accumulate(models_to_accumulate):
662
+ if not self.args.precompute_conditions:
663
+ videos = batch["videos"]
664
+ prompts = batch["prompts"]
665
+ batch_size = len(prompts)
666
+
667
+ if self.args.caption_dropout_technique == "empty":
668
+ if random.random() < self.args.caption_dropout_p:
669
+ prompts = [""] * batch_size
670
+
671
+ latent_conditions = self.model_config["prepare_latents"](
672
+ vae=self.vae,
673
+ image_or_video=videos,
674
+ patch_size=self.transformer_config.patch_size,
675
+ patch_size_t=self.transformer_config.patch_size_t,
676
+ device=accelerator.device,
677
+ dtype=self.args.transformer_dtype,
678
+ generator=self.state.generator,
679
+ )
680
+ text_conditions = self.model_config["prepare_conditions"](
681
+ tokenizer=self.tokenizer,
682
+ text_encoder=self.text_encoder,
683
+ tokenizer_2=self.tokenizer_2,
684
+ text_encoder_2=self.text_encoder_2,
685
+ prompt=prompts,
686
+ device=accelerator.device,
687
+ dtype=self.args.transformer_dtype,
688
+ )
689
+ else:
690
+ latent_conditions = batch["latent_conditions"]
691
+ text_conditions = batch["text_conditions"]
692
+ latent_conditions["latents"] = DiagonalGaussianDistribution(
693
+ latent_conditions["latents"]
694
+ ).sample(self.state.generator)
695
+
696
+ # This method should only be called for precomputed latents.
697
+ # TODO(aryan): rename this in separate PR
698
+ latent_conditions = self.model_config["post_latent_preparation"](
699
+ vae_config=self.vae_config,
700
+ patch_size=self.transformer_config.patch_size,
701
+ patch_size_t=self.transformer_config.patch_size_t,
702
+ **latent_conditions,
703
+ )
704
+ align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
705
+ align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
706
+ batch_size = latent_conditions["latents"].shape[0]
707
+
708
+ latent_conditions = make_contiguous(latent_conditions)
709
+ text_conditions = make_contiguous(text_conditions)
710
+
711
+ if self.args.caption_dropout_technique == "zero":
712
+ if random.random() < self.args.caption_dropout_p:
713
+ text_conditions["prompt_embeds"].fill_(0)
714
+ text_conditions["prompt_attention_mask"].fill_(False)
715
+
716
+ # TODO(aryan): refactor later
717
+ if "pooled_prompt_embeds" in text_conditions:
718
+ text_conditions["pooled_prompt_embeds"].fill_(0)
719
+
720
+ sigmas = prepare_sigmas(
721
+ scheduler=self.scheduler,
722
+ sigmas=scheduler_sigmas,
723
+ batch_size=batch_size,
724
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
725
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
726
+ flow_logit_mean=self.args.flow_logit_mean,
727
+ flow_logit_std=self.args.flow_logit_std,
728
+ flow_mode_scale=self.args.flow_mode_scale,
729
+ device=accelerator.device,
730
+ generator=self.state.generator,
731
+ )
732
+ timesteps = (sigmas * 1000.0).long()
733
+
734
+ noise = torch.randn(
735
+ latent_conditions["latents"].shape,
736
+ generator=self.state.generator,
737
+ device=accelerator.device,
738
+ dtype=self.args.transformer_dtype,
739
+ )
740
+ sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
741
+
742
+ # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
743
+ # scheduler and calculate the noisy latents accordingly. Look into this later.
744
+ if "calculate_noisy_latents" in self.model_config.keys():
745
+ noisy_latents = self.model_config["calculate_noisy_latents"](
746
+ scheduler=self.scheduler,
747
+ noise=noise,
748
+ latents=latent_conditions["latents"],
749
+ timesteps=timesteps,
750
+ )
751
+ else:
752
+ # Default to flow-matching noise addition
753
+ noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
754
+ noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
755
+
756
+ latent_conditions.update({"noisy_latents": noisy_latents})
757
+
758
+ weights = prepare_loss_weights(
759
+ scheduler=self.scheduler,
760
+ alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
761
+ sigmas=sigmas,
762
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
763
+ )
764
+ weights = expand_tensor_dims(weights, noise.ndim)
765
+
766
+ pred = self.model_config["forward_pass"](
767
+ transformer=self.transformer,
768
+ scheduler=self.scheduler,
769
+ timesteps=timesteps,
770
+ **latent_conditions,
771
+ **text_conditions,
772
+ )
773
+ target = prepare_target(
774
+ scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
775
+ )
776
+
777
+ loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
778
+ # Average loss across all but batch dimension
779
+ loss = loss.mean(list(range(1, loss.ndim)))
780
+ # Average loss across batch dimension
781
+ loss = loss.mean()
782
+ accelerator.backward(loss)
783
+
784
+ if accelerator.sync_gradients:
785
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
786
+ grad_norm = self.transformer.get_global_grad_norm()
787
+ # In some cases the grad norm may not return a float
788
+ if torch.is_tensor(grad_norm):
789
+ grad_norm = grad_norm.item()
790
+ else:
791
+ grad_norm = accelerator.clip_grad_norm_(
792
+ self.transformer.parameters(), self.args.max_grad_norm
793
+ )
794
+ if torch.is_tensor(grad_norm):
795
+ grad_norm = grad_norm.item()
796
+
797
+ logs["grad_norm"] = grad_norm
798
+
799
+ self.optimizer.step()
800
+ self.lr_scheduler.step()
801
+ self.optimizer.zero_grad()
802
+
803
+ # Checks if the accelerator has performed an optimization step behind the scenes
804
+ if accelerator.sync_gradients:
805
+ progress_bar.update(1)
806
+ global_step += 1
807
+
808
+ # Checkpointing
809
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
810
+ if global_step % self.args.checkpointing_steps == 0:
811
+ save_path = get_intermediate_ckpt_path(
812
+ checkpointing_limit=self.args.checkpointing_limit,
813
+ step=global_step,
814
+ output_dir=self.args.output_dir,
815
+ )
816
+ accelerator.save_state(save_path)
817
+
818
+ # Maybe run validation
819
+ should_run_validation = (
820
+ self.args.validation_every_n_steps is not None
821
+ and global_step % self.args.validation_every_n_steps == 0
822
+ )
823
+ if should_run_validation:
824
+ self.validate(global_step)
825
+
826
+ loss_item = loss.detach().item()
827
+ epoch_loss += loss_item
828
+ num_loss_updates += 1
829
+ logs["step_loss"] = loss_item
830
+ logs["lr"] = self.lr_scheduler.get_last_lr()[0]
831
+ progress_bar.set_postfix(logs)
832
+ accelerator.log(logs, step=global_step)
833
+
834
+ if global_step % 100 == 0: # Every 100 steps
835
+ # Force garbage collection to clean up any lingering resources
836
+ gc.collect()
837
+
838
+ if global_step >= self.state.train_steps:
839
+ break
840
+
841
+
842
+
843
+ if num_loss_updates > 0:
844
+ epoch_loss /= num_loss_updates
845
+ accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
846
+ memory_statistics = get_memory_statistics()
847
+ logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
848
+
849
+ # Maybe run validation
850
+ should_run_validation = (
851
+ self.args.validation_every_n_epochs is not None
852
+ and (epoch + 1) % self.args.validation_every_n_epochs == 0
853
+ )
854
+ if should_run_validation:
855
+ self.validate(global_step)
856
+
857
+ if epoch % 3 == 0: # Every 3 epochs
858
+ logger.info("Performing periodic resource cleanup")
859
+ free_memory()
860
+ gc.collect()
861
+ torch.cuda.empty_cache()
862
+ torch.cuda.synchronize(accelerator.device)
863
+
864
+ accelerator.wait_for_everyone()
865
+ if accelerator.is_main_process:
866
+ transformer = unwrap_model(accelerator, self.transformer)
867
+
868
+ if self.args.training_type == "lora":
869
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
870
+
871
+ self.model_config["pipeline_cls"].save_lora_weights(
872
+ save_directory=self.args.output_dir,
873
+ transformer_lora_layers=transformer_lora_layers,
874
+ )
875
+ else:
876
+ transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
877
+ accelerator.wait_for_everyone()
878
+ self.validate(step=global_step, final_validation=True)
879
+
880
+ if accelerator.is_main_process:
881
+ if self.args.push_to_hub:
882
+ upload_folder(
883
+ repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
884
+ )
885
+
886
+ self._delete_components()
887
+ memory_statistics = get_memory_statistics()
888
+ logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
889
+
890
+ accelerator.end_training()
891
+
892
+ def validate(self, step: int, final_validation: bool = False) -> None:
893
+ logger.info("Starting validation")
894
+
895
+ accelerator = self.state.accelerator
896
+ num_validation_samples = len(self.args.validation_prompts)
897
+
898
+ if num_validation_samples == 0:
899
+ logger.warning("No validation samples found. Skipping validation.")
900
+ if accelerator.is_main_process:
901
+ if self.args.push_to_hub:
902
+ save_model_card(
903
+ args=self.args,
904
+ repo_id=self.state.repo_id,
905
+ videos=None,
906
+ validation_prompts=None,
907
+ )
908
+ return
909
+
910
+ self.transformer.eval()
911
+
912
+ memory_statistics = get_memory_statistics()
913
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
914
+
915
+ pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)
916
+
917
+ all_processes_artifacts = []
918
+ prompts_to_filenames = {}
919
+ for i in range(num_validation_samples):
920
+ # Skip current validation on all processes but one
921
+ if i % accelerator.num_processes != accelerator.process_index:
922
+ continue
923
+
924
+ prompt = self.args.validation_prompts[i]
925
+ image = self.args.validation_images[i]
926
+ video = self.args.validation_videos[i]
927
+ height = self.args.validation_heights[i]
928
+ width = self.args.validation_widths[i]
929
+ num_frames = self.args.validation_num_frames[i]
930
+ frame_rate = self.args.validation_frame_rate
931
+ if image is not None:
932
+ image = load_image(image)
933
+ if video is not None:
934
+ video = load_video(video)
935
+
936
+ logger.debug(
937
+ f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
938
+ main_process_only=False,
939
+ )
940
+ validation_artifacts = self.model_config["validation"](
941
+ pipeline=pipeline,
942
+ prompt=prompt,
943
+ image=image,
944
+ video=video,
945
+ height=height,
946
+ width=width,
947
+ num_frames=num_frames,
948
+ frame_rate=frame_rate,
949
+ num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
950
+ generator=torch.Generator(device=accelerator.device).manual_seed(
951
+ self.args.seed if self.args.seed is not None else 0
952
+ ),
953
+ # todo support passing `fps` for supported pipelines.
954
+ )
955
+
956
+ prompt_filename = string_to_filename(prompt)[:25]
957
+ artifacts = {
958
+ "image": {"type": "image", "value": image},
959
+ "video": {"type": "video", "value": video},
960
+ }
961
+ for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
962
+ if artifact_value:
963
+ artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
964
+ logger.debug(
965
+ f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
966
+ main_process_only=False,
967
+ )
968
+
969
+ for index, (key, value) in enumerate(list(artifacts.items())):
970
+ artifact_type = value["type"]
971
+ artifact_value = value["value"]
972
+ if artifact_type not in ["image", "video"] or artifact_value is None:
973
+ continue
974
+
975
+ extension = "png" if artifact_type == "image" else "mp4"
976
+ filename = "validation-" if not final_validation else "final-"
977
+ filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
978
+ if accelerator.is_main_process and extension == "mp4":
979
+ prompts_to_filenames[prompt] = filename
980
+ filename = os.path.join(self.args.output_dir, filename)
981
+
982
+ if artifact_type == "image" and artifact_value:
983
+ logger.debug(f"Saving image to {filename}")
984
+ artifact_value.save(filename)
985
+ artifact_value = wandb.Image(filename)
986
+ elif artifact_type == "video" and artifact_value:
987
+ logger.debug(f"Saving video to {filename}")
988
+ # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
989
+ export_to_video(artifact_value, filename, fps=frame_rate)
990
+ artifact_value = wandb.Video(filename, caption=prompt)
991
+
992
+ all_processes_artifacts.append(artifact_value)
993
+
994
+ all_artifacts = gather_object(all_processes_artifacts)
995
+
996
+ if accelerator.is_main_process:
997
+ tracker_key = "final" if final_validation else "validation"
998
+ for tracker in accelerator.trackers:
999
+ if tracker.name == "wandb":
1000
+ artifact_log_dict = {}
1001
+
1002
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
1003
+ if len(image_artifacts) > 0:
1004
+ artifact_log_dict["images"] = image_artifacts
1005
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
1006
+ if len(video_artifacts) > 0:
1007
+ artifact_log_dict["videos"] = video_artifacts
1008
+ tracker.log({tracker_key: artifact_log_dict}, step=step)
1009
+
1010
+ if self.args.push_to_hub and final_validation:
1011
+ video_filenames = list(prompts_to_filenames.values())
1012
+ prompts = list(prompts_to_filenames.keys())
1013
+ save_model_card(
1014
+ args=self.args,
1015
+ repo_id=self.state.repo_id,
1016
+ videos=video_filenames,
1017
+ validation_prompts=prompts,
1018
+ )
1019
+
1020
+ # Remove all hooks that might have been added during pipeline initialization to the models
1021
+ pipeline.remove_all_hooks()
1022
+ del pipeline
1023
+
1024
+ accelerator.wait_for_everyone()
1025
+
1026
+ free_memory()
1027
+ memory_statistics = get_memory_statistics()
1028
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
1029
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
1030
+
1031
+ if not final_validation:
1032
+ self.transformer.train()
1033
+
1034
+ def evaluate(self) -> None:
1035
+ raise NotImplementedError("Evaluation has not been implemented yet.")
1036
+
1037
+ def _init_distributed(self) -> None:
1038
+ logging_dir = Path(self.args.output_dir, self.args.logging_dir)
1039
+ project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
1040
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
1041
+ init_process_group_kwargs = InitProcessGroupKwargs(
1042
+ backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
1043
+ )
1044
+ report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
1045
+
1046
+ accelerator = Accelerator(
1047
+ project_config=project_config,
1048
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
1049
+ log_with=report_to,
1050
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
1051
+ )
1052
+
1053
+ # Disable AMP for MPS.
1054
+ if torch.backends.mps.is_available():
1055
+ accelerator.native_amp = False
1056
+
1057
+ self.state.accelerator = accelerator
1058
+
1059
+ if self.args.seed is not None:
1060
+ self.state.seed = self.args.seed
1061
+ set_seed(self.args.seed)
1062
+
1063
+ def _init_logging(self) -> None:
1064
+ logging.basicConfig(
1065
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1066
+ datefmt="%m/%d/%Y %H:%M:%S",
1067
+ level=FINETRAINERS_LOG_LEVEL,
1068
+ )
1069
+ if self.state.accelerator.is_local_main_process:
1070
+ transformers.utils.logging.set_verbosity_warning()
1071
+ diffusers.utils.logging.set_verbosity_info()
1072
+ else:
1073
+ transformers.utils.logging.set_verbosity_error()
1074
+ diffusers.utils.logging.set_verbosity_error()
1075
+
1076
+ logger.info("Initialized FineTrainers")
1077
+ logger.info(self.state.accelerator.state, main_process_only=False)
1078
+
1079
+ def _init_directories_and_repositories(self) -> None:
1080
+ if self.state.accelerator.is_main_process:
1081
+ self.args.output_dir = Path(self.args.output_dir)
1082
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
1083
+ self.state.output_dir = Path(self.args.output_dir)
1084
+
1085
+ if self.args.push_to_hub:
1086
+ repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
1087
+ self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
1088
+
1089
+ def _init_config_options(self) -> None:
1090
+ # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1091
+ if self.args.allow_tf32 and torch.cuda.is_available():
1092
+ torch.backends.cuda.matmul.allow_tf32 = True
1093
+
1094
+ def _move_components_to_device(self):
1095
+ if self.text_encoder is not None:
1096
+ self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
1097
+ if self.text_encoder_2 is not None:
1098
+ self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
1099
+ if self.text_encoder_3 is not None:
1100
+ self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
1101
+ if self.transformer is not None:
1102
+ self.transformer = self.transformer.to(self.state.accelerator.device)
1103
+ if self.unet is not None:
1104
+ self.unet = self.unet.to(self.state.accelerator.device)
1105
+ if self.vae is not None:
1106
+ self.vae = self.vae.to(self.state.accelerator.device)
1107
+
1108
+ def _get_load_components_kwargs(self) -> Dict[str, Any]:
1109
+ load_component_kwargs = {
1110
+ "text_encoder_dtype": self.args.text_encoder_dtype,
1111
+ "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
1112
+ "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
1113
+ "transformer_dtype": self.args.transformer_dtype,
1114
+ "vae_dtype": self.args.vae_dtype,
1115
+ "shift": self.args.flow_shift,
1116
+ "revision": self.args.revision,
1117
+ "cache_dir": self.args.cache_dir,
1118
+ }
1119
+ if self.args.pretrained_model_name_or_path is not None:
1120
+ load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
1121
+ return load_component_kwargs
1122
+
1123
+ def _set_components(self, components: Dict[str, Any]) -> None:
1124
+ # Set models
1125
+ self.tokenizer = components.get("tokenizer", self.tokenizer)
1126
+ self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
1127
+ self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
1128
+ self.text_encoder = components.get("text_encoder", self.text_encoder)
1129
+ self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
1130
+ self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
1131
+ self.transformer = components.get("transformer", self.transformer)
1132
+ self.unet = components.get("unet", self.unet)
1133
+ self.vae = components.get("vae", self.vae)
1134
+ self.scheduler = components.get("scheduler", self.scheduler)
1135
+
1136
+ # Set configs
1137
+ self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
1138
+ self.vae_config = self.vae.config if self.vae is not None else self.vae_config
1139
+
1140
+ def _delete_components(self) -> None:
1141
+ self.tokenizer = None
1142
+ self.tokenizer_2 = None
1143
+ self.tokenizer_3 = None
1144
+ self.text_encoder = None
1145
+ self.text_encoder_2 = None
1146
+ self.text_encoder_3 = None
1147
+ self.transformer = None
1148
+ self.unet = None
1149
+ self.vae = None
1150
+ self.scheduler = None
1151
+ free_memory()
1152
+ torch.cuda.synchronize(self.state.accelerator.device)
1153
+
1154
+ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
1155
+ accelerator = self.state.accelerator
1156
+ if not final_validation:
1157
+ pipeline = self.model_config["initialize_pipeline"](
1158
+ model_id=self.args.pretrained_model_name_or_path,
1159
+ tokenizer=self.tokenizer,
1160
+ text_encoder=self.text_encoder,
1161
+ tokenizer_2=self.tokenizer_2,
1162
+ text_encoder_2=self.text_encoder_2,
1163
+ transformer=unwrap_model(accelerator, self.transformer),
1164
+ vae=self.vae,
1165
+ device=accelerator.device,
1166
+ revision=self.args.revision,
1167
+ cache_dir=self.args.cache_dir,
1168
+ enable_slicing=self.args.enable_slicing,
1169
+ enable_tiling=self.args.enable_tiling,
1170
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1171
+ is_training=True,
1172
+ )
1173
+ else:
1174
+ self._delete_components()
1175
+
1176
+ # Load the transformer weights from the final checkpoint if performing full-finetune
1177
+ transformer = None
1178
+ if self.args.training_type == "full-finetune":
1179
+ transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
1180
+
1181
+ pipeline = self.model_config["initialize_pipeline"](
1182
+ model_id=self.args.pretrained_model_name_or_path,
1183
+ transformer=transformer,
1184
+ device=accelerator.device,
1185
+ revision=self.args.revision,
1186
+ cache_dir=self.args.cache_dir,
1187
+ enable_slicing=self.args.enable_slicing,
1188
+ enable_tiling=self.args.enable_tiling,
1189
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1190
+ is_training=False,
1191
+ )
1192
+
1193
+ # Load the LoRA weights if performing LoRA finetuning
1194
+ if self.args.training_type == "lora":
1195
+ pipeline.load_lora_weights(self.args.output_dir)
1196
+
1197
+ return pipeline
1198
+
1199
+ def _disable_grad_for_components(self, components: List[torch.nn.Module]):
1200
+ for component in components:
1201
+ if component is not None:
1202
+ component.requires_grad_(False)
1203
+
1204
+ def _enable_grad_for_components(self, components: List[torch.nn.Module]):
1205
+ for component in components:
1206
+ if component is not None:
1207
+ component.requires_grad_(True)
1208
+
1209
+ def _get_training_info(self) -> dict:
1210
+ args = self.args.to_dict()
1211
+
1212
+ training_args = args.get("training_arguments", {})
1213
+ training_type = training_args.get("training_type", "")
1214
+
1215
+ # LoRA/non-LoRA stuff.
1216
+ if training_type == "full-finetune":
1217
+ filtered_training_args = {
1218
+ k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
1219
+ }
1220
+ else:
1221
+ filtered_training_args = training_args
1222
+
1223
+ # Diffusion/flow stuff.
1224
+ diffusion_args = args.get("diffusion_arguments", {})
1225
+ scheduler_name = self.scheduler.__class__.__name__
1226
+ if scheduler_name != "FlowMatchEulerDiscreteScheduler":
1227
+ filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
1228
+ else:
1229
+ filtered_diffusion_args = diffusion_args
1230
+
1231
+ # Rest of the stuff.
1232
+ updated_training_info = args.copy()
1233
+ updated_training_info["training_arguments"] = filtered_training_args
1234
+ updated_training_info["diffusion_arguments"] = filtered_diffusion_args
1235
+ return updated_training_info
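The heart of the train() loop above is the flow-matching objective: sampled sigmas interpolate the clean latents toward noise, the transformer predicts a velocity-like target, and a weighted MSE is averaged over all non-batch dimensions. The following is a condensed, self-contained sketch of that arithmetic under simplifying assumptions (uniform sigma sampling, unit loss weights, a generic forward_pass callable); the real code delegates to prepare_sigmas, prepare_loss_weights, prepare_target and the model-specific forward_pass instead.

from typing import Callable

import torch


def flow_matching_loss(
    latents: torch.Tensor,
    forward_pass: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    generator: torch.Generator,
) -> torch.Tensor:
    batch_size = latents.shape[0]
    # One sigma in [0, 1) per sample; the trainer also supports logit-normal and other schemes.
    sigmas = torch.rand(batch_size, device=latents.device, generator=generator)
    timesteps = (sigmas * 1000.0).long()
    # Broadcast sigmas over all non-batch dimensions (what expand_tensor_dims does above).
    sigmas = sigmas.reshape(batch_size, *([1] * (latents.ndim - 1)))

    noise = torch.randn(
        latents.shape, device=latents.device, dtype=latents.dtype, generator=generator
    )
    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

    pred = forward_pass(noisy_latents, timesteps)
    target = noise - latents  # flow-matching velocity target
    loss = (pred.float() - target.float()).pow(2)
    loss = loss.mean(list(range(1, loss.ndim)))  # average over all but the batch dimension
    return loss.mean()

In use, forward_pass would wrap the transformer call for the specific model family; the exact call signature is an assumption about that interface, not code from this commit.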
finetrainers/trainer.py CHANGED
@@ -7,7 +7,7 @@ import random
 from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any, Dict, List
-
+import resource
 import diffusers
 import torch
 import torch.backends
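Note that the resource module is POSIX-only, so this unconditional import would raise ImportError on Windows; the added trainer guards usage with hasattr(resource, 'RLIMIT_NOFILE') but not the import itself. A portable variant could look like the sketch below; this is an assumption about a possible hardening, not what the commit does.

try:
    import resource  # POSIX-only; not available on Windows
except ImportError:
    resource = None


def _maybe_raise_nofile_limit() -> None:
    # Silently skip on platforms without RLIMIT_NOFILE support.
    if resource is None or not hasattr(resource, "RLIMIT_NOFILE"):
        return
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft < hard:
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))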
vms/services/trainer.py CHANGED
@@ -153,7 +153,7 @@ class TrainingService:
                 # Make sure we have all keys (in case structure changed)
                 merged_state = default_state.copy()
                 merged_state.update(saved_state)
-                logger.info(f"Successfully loaded UI state from {ui_state_file}")
                 return merged_state
         except json.JSONDecodeError as e:
             logger.error(f"Error parsing UI state JSON: {str(e)}")
@@ -637,49 +637,68 @@ class TrainingService:
637
  return False
638
 
639
  def recover_interrupted_training(self) -> Dict[str, Any]:
640
- """Attempt to recover interrupted training
641
-
642
- Returns:
643
- Dict with recovery status and UI updates
644
- """
645
- status = self.get_status()
646
- ui_updates = {}
647
-
648
- # Check for any checkpoints, even if status doesn't indicate training
649
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
650
- has_checkpoints = len(checkpoints) > 0
651
-
652
- # If status indicates training but process isn't running, or if we have checkpoints
653
- # and no active training process, try to recover
654
- if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
655
- (has_checkpoints and not self.is_training_running()):
656
 
657
- logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
 
 
 
 
658
 
659
- # Get the latest checkpoint
660
- last_session = self.load_session()
 
661
 
662
- if not last_session:
663
- logger.warning("No session data found for recovery, but will check for checkpoints")
664
- # Try to create a default session based on UI state if we have checkpoints
665
- if has_checkpoints:
666
- ui_state = self.load_ui_state()
667
- # Create a default session using UI state values
668
- last_session = {
669
- "params": {
670
- "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
671
- "lora_rank": ui_state.get("lora_rank", "128"),
672
- "lora_alpha": ui_state.get("lora_alpha", "128"),
673
- "num_epochs": ui_state.get("num_epochs", 70),
674
- "batch_size": ui_state.get("batch_size", 1),
675
- "learning_rate": ui_state.get("learning_rate", 3e-5),
676
- "save_iterations": ui_state.get("save_iterations", 500),
677
- "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
678
- "repo_id": "" # Default empty repo ID
 
 
 
 
 
 
 
 
 
 
 
679
  }
680
- }
681
- logger.info("Created default session from UI state for recovery")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  else:
 
683
  # Set buttons for no active training
684
  ui_updates = {
685
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
@@ -687,116 +706,98 @@ class TrainingService:
687
  "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
688
  "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
689
  }
690
- return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
691
-
692
- # Find the latest checkpoint if we have checkpoints
693
- latest_checkpoint = None
694
- checkpoint_step = 0
695
-
696
- if has_checkpoints:
697
- latest_checkpoint = max(checkpoints, key=os.path.getmtime)
698
- checkpoint_step = int(latest_checkpoint.name.split("-")[1])
699
- logger.info(f"Found checkpoint at step {checkpoint_step}")
700
- else:
701
- logger.warning("No checkpoints found for recovery")
702
- # Set buttons for no active training
703
- ui_updates = {
704
- "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
705
- "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
706
- "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
707
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
708
- }
709
- return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
710
-
711
- # Extract parameters from the saved session (not current UI state)
712
- # This ensures we use the original training parameters
713
- params = last_session.get('params', {})
714
-
715
- # Map internal model type back to display name for UI
716
- # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
717
- model_type_internal = params.get('model_type')
718
- model_type_display = model_type_internal
719
-
720
- # Find the display name that maps to our internal model type
721
- for display_name, internal_name in MODEL_TYPES.items():
722
- if internal_name == model_type_internal:
723
- model_type_display = display_name
724
- logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
725
- break
726
-
727
- # Add UI updates to restore the training parameters in the UI
728
- # This shows the user what values are being used for the resumed training
729
- ui_updates.update({
730
- "model_type": model_type_display, # Use the display name for the UI dropdown
731
- "lora_rank": params.get('lora_rank', "128"),
732
- "lora_alpha": params.get('lora_alpha', "128"),
733
- "num_epochs": params.get('num_epochs', 70),
734
- "batch_size": params.get('batch_size', 1),
735
- "learning_rate": params.get('learning_rate', 3e-5),
736
- "save_iterations": params.get('save_iterations', 500),
737
- "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
738
- })
739
-
740
- # Check if we should auto-recover (immediate restart)
741
- auto_recover = True # Always auto-recover on startup
742
-
743
- if auto_recover:
744
- # Rest of the auto-recovery code remains unchanged
745
- try:
746
- # Use the internal model_type for the actual training
747
- # But keep model_type_display for the UI
748
- result = self.start_training(
749
- model_type=model_type_internal,
750
- lora_rank=params.get('lora_rank', "128"),
751
- lora_alpha=params.get('lora_alpha', "128"),
752
- num_epochs=params.get('num_epochs', 70),
753
- batch_size=params.get('batch_size', 1),
754
- learning_rate=params.get('learning_rate', 3e-5),
755
- save_iterations=params.get('save_iterations', 500),
756
- repo_id=params.get('repo_id', ''),
757
- preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
758
- resume_from_checkpoint=str(latest_checkpoint)
759
- )
760
-
761
- # Set buttons for active training
762
- ui_updates.update({
763
- "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
764
- "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
765
- "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
766
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
767
- })
768
-
769
- return {
770
- "status": "recovered",
771
- "message": f"Training resumed from checkpoint {checkpoint_step}",
772
- "result": result,
773
- "ui_updates": ui_updates
774
- }
775
- except Exception as e:
776
- logger.error(f"Failed to auto-resume training: {str(e)}")
777
- # Set buttons for manual recovery
778
- ui_updates.update({
779
- "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
780
- "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
781
- "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"},
782
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
783
- })
784
- return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
785
- else:
786
- # Set up UI for manual recovery
787
- ui_updates.update({
788
- "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
789
- "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
790
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
791
- })
792
- return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates}
793
-
794
  elif self.is_training_running():
795
  # Process is still running, set buttons accordingly
796
  ui_updates = {
797
  "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"},
798
  "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
799
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
 
800
  }
801
  return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
802
  else:
@@ -805,10 +806,11 @@ class TrainingService:
805
  ui_updates = {
806
  "start_btn": {"interactive": True, "variant": "primary", "value": button_text},
807
  "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
808
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
 
809
  }
810
  return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
811
-
812
  def delete_all_checkpoints(self) -> str:
813
  """Delete all checkpoints in the output directory.
814
 
 
153
  # Make sure we have all keys (in case structure changed)
154
  merged_state = default_state.copy()
155
  merged_state.update(saved_state)
156
+ #logger.info(f"Successfully loaded UI state from {ui_state_file}")
157
  return merged_state
158
  except json.JSONDecodeError as e:
159
  logger.error(f"Error parsing UI state JSON: {str(e)}")
 
637
  return False
638
 
639
  def recover_interrupted_training(self) -> Dict[str, Any]:
640
+ """Attempt to recover interrupted training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
+ Returns:
643
+ Dict with recovery status and UI updates
644
+ """
645
+ status = self.get_status()
646
+ ui_updates = {}
647
 
648
+ # Check for any checkpoints, even if status doesn't indicate training
649
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
650
+ has_checkpoints = len(checkpoints) > 0
651
 
652
+ # If status indicates training but process isn't running, or if we have checkpoints
653
+ # and no active training process, try to recover
654
+ if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
655
+ (has_checkpoints and not self.is_training_running()):
656
+
657
+ logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
658
+
659
+ # Get the latest checkpoint
660
+ last_session = self.load_session()
661
+
662
+ if not last_session:
663
+ logger.warning("No session data found for recovery, but will check for checkpoints")
664
+ # Try to create a default session based on UI state if we have checkpoints
665
+ if has_checkpoints:
666
+ ui_state = self.load_ui_state()
667
+ # Create a default session using UI state values
668
+ last_session = {
669
+ "params": {
670
+ "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
671
+ "lora_rank": ui_state.get("lora_rank", "128"),
672
+ "lora_alpha": ui_state.get("lora_alpha", "128"),
673
+ "num_epochs": ui_state.get("num_epochs", 70),
674
+ "batch_size": ui_state.get("batch_size", 1),
675
+ "learning_rate": ui_state.get("learning_rate", 3e-5),
676
+ "save_iterations": ui_state.get("save_iterations", 500),
677
+ "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
678
+ "repo_id": "" # Default empty repo ID
679
+ }
680
  }
681
+ logger.info("Created default session from UI state for recovery")
682
+ else:
683
+ # Set buttons for no active training
684
+ ui_updates = {
685
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
686
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
687
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
688
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
689
+ }
690
+ return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
691
+
692
+ # Find the latest checkpoint if we have checkpoints
693
+ latest_checkpoint = None
694
+ checkpoint_step = 0
695
+
696
+ if has_checkpoints:
697
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
698
+ checkpoint_step = int(latest_checkpoint.name.split("-")[1])
699
+ logger.info(f"Found checkpoint at step {checkpoint_step}")
700
  else:
701
+ logger.warning("No checkpoints found for recovery")
702
  # Set buttons for no active training
703
  ui_updates = {
704
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
 
706
  "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
707
  "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
708
  }
709
+ return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
710
+
711
+ # Extract parameters from the saved session (not current UI state)
712
+ # This ensures we use the original training parameters
713
+ params = last_session.get('params', {})
714
+
715
+ # Map internal model type back to display name for UI
716
+ # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
717
+ model_type_internal = params.get('model_type')
718
+ model_type_display = model_type_internal
719
+
720
+ # Find the display name that maps to our internal model type
721
+ for display_name, internal_name in MODEL_TYPES.items():
722
+ if internal_name == model_type_internal:
723
+ model_type_display = display_name
724
+ logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
725
+ break
726
+
727
+ # Add UI updates to restore the training parameters in the UI
728
+ # This shows the user what values are being used for the resumed training
729
+ ui_updates.update({
730
+ "model_type": model_type_display, # Use the display name for the UI dropdown
731
+ "lora_rank": params.get('lora_rank', "128"),
732
+ "lora_alpha": params.get('lora_alpha', "128"),
733
+ "num_epochs": params.get('num_epochs', 70),
734
+ "batch_size": params.get('batch_size', 1),
735
+ "learning_rate": params.get('learning_rate', 3e-5),
736
+ "save_iterations": params.get('save_iterations', 500),
737
+ "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
738
+ })
739
+
740
+ # Check if we should auto-recover (immediate restart)
741
+ auto_recover = True # Always auto-recover on startup
742
+
743
+ if auto_recover:
744
+ # Rest of the auto-recovery code remains unchanged
745
+ try:
746
+ # Use the internal model_type for the actual training
747
+ # But keep model_type_display for the UI
748
+ result = self.start_training(
749
+ model_type=model_type_internal,
750
+ lora_rank=params.get('lora_rank', "128"),
751
+ lora_alpha=params.get('lora_alpha', "128"),
752
+ num_epochs=params.get('num_epochs', 70),
753
+ batch_size=params.get('batch_size', 1),
754
+ learning_rate=params.get('learning_rate', 3e-5),
755
+ save_iterations=params.get('save_iterations', 500),
756
+ repo_id=params.get('repo_id', ''),
757
+ preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
758
+ resume_from_checkpoint=str(latest_checkpoint)
759
+ )
760
+
761
+ # Set buttons for active training
762
+ ui_updates.update({
763
+ "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
764
+ "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
765
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
766
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
767
+ })
768
+
769
+ return {
770
+ "status": "recovered",
771
+ "message": f"Training resumed from checkpoint {checkpoint_step}",
772
+ "result": result,
773
+ "ui_updates": ui_updates
774
+ }
775
+ except Exception as e:
776
+ logger.error(f"Failed to auto-resume training: {str(e)}")
777
+ # Set buttons for manual recovery
778
+ ui_updates.update({
779
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
780
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
781
+ "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"},
782
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
783
+ })
784
+ return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
785
+ else:
786
+ # Set up UI for manual recovery
787
+ ui_updates.update({
788
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
789
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
790
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
791
+ })
792
+ return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates}
793
+
794
  elif self.is_training_running():
795
  # Process is still running, set buttons accordingly
796
  ui_updates = {
797
  "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"},
798
  "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
799
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False},
800
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
801
  }
802
  return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
803
  else:
 
806
  ui_updates = {
807
  "start_btn": {"interactive": True, "variant": "primary", "value": button_text},
808
  "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
809
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False},
810
+ "delete_checkpoints_btn": {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
811
  }
812
  return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
813
+
814
  def delete_all_checkpoints(self) -> str:
815
  """Delete all checkpoints in the output directory.
816
 
vms/ui/video_trainer_ui.py CHANGED
@@ -31,6 +31,10 @@ class VideoTrainerUI:
31
 
32
  # Recovery status from any interrupted training
33
  recovery_result = self.trainer.recover_interrupted_training()
34
  self.recovery_status = recovery_result.get("status", "unknown")
35
  self.ui_updates = recovery_result.get("ui_updates", {})
36
 
 
31
 
32
  # Recovery status from any interrupted training
33
  recovery_result = self.trainer.recover_interrupted_training()
34
+ # Add null check for recovery_result
35
+ if recovery_result is None:
36
+ recovery_result = {"status": "unknown", "ui_updates": {}}
37
+
38
  self.recovery_status = recovery_result.get("status", "unknown")
39
  self.ui_updates = recovery_result.get("ui_updates", {})
40
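Note: the null check added above prevents an `AttributeError` when `recover_interrupted_training()` returns `None`. An equivalent, more compact form of the same guard (a sketch with a hypothetical helper name, not the committed code):

```python
from typing import Tuple

def init_recovery_state(trainer) -> Tuple[str, dict]:
    # Fall back to a neutral result if recovery returned None.
    recovery_result = trainer.recover_interrupted_training() or {"status": "unknown", "ui_updates": {}}
    return recovery_result.get("status", "unknown"), recovery_result.get("ui_updates", {})
```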