Commit 57737a0 by jbilcke · Parent: 02e94ba

upgrade Finetrainers

finetrainers/models/cogvideox/base_specification.py CHANGED
@@ -299,7 +299,7 @@ class CogVideoXModelSpecification(ModelSpecification):
             latents = posterior.sample(generator=generator)
             del posterior
 
-        if not self.vae_config.invert_scale_latents:
+        if not getattr(self.vae_config, "invert_scale_latents", False):
             latents = latents * self.vae_config.scaling_factor
 
         if patch_size_t is not None:
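Note on this change: reading `invert_scale_latents` directly raises `AttributeError` on VAE configs that do not define the flag, so `getattr` with a `False` default presumably keeps the original scaling behavior for older configs. A minimal sketch of the fallback (the config object here is hypothetical, not the real VAE config):

from types import SimpleNamespace

# Hypothetical config that predates the `invert_scale_latents` flag.
old_config = SimpleNamespace(scaling_factor=0.7)

# old_config.invert_scale_latents would raise AttributeError;
# getattr treats the missing flag as False, so scaling still applies.
if not getattr(old_config, "invert_scale_latents", False):
    print("scaling latents by", old_config.scaling_factor)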
finetrainers/models/ltx_video/base_specification.py CHANGED
@@ -336,8 +336,8 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents = self._pack_latents(latents, patch_size, patch_size_t)
         noise = self._pack_latents(noise, patch_size, patch_size_t)
         noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
-
         sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
+        timesteps = (sigmas * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
@@ -352,7 +352,6 @@ class LTXVideoModelSpecification(ModelSpecification):
             vae_spatial_compression_ratio,
             vae_spatial_compression_ratio,
         ]
-        timesteps = (sigmas * 1000.0).long()
 
         pred = transformer(
             **latent_model_conditions,
@@ -444,9 +443,9 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
         # Normalize latents across the channel dimension [B, C, F, H, W]
-        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-        latents = (latents - latents_mean) * scaling_factor / latents_std
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
         return latents
 
     @staticmethod
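Two things happen in this file: `timesteps` is now derived from the expanded per-token `sigmas` right where they are built (instead of later in the function), and `_normalize_latents` now upcasts to float32 for the mean/std arithmetic before casting back to the input dtype via `.to(latents)`, which avoids precision loss when training in bf16/fp16. A small illustrative sketch of that compute-in-fp32 pattern (shapes and values are made up):

import torch

# Made-up shapes: [B, C, F, H, W] latents in bfloat16, per-channel stats in fp32.
latents = torch.randn(1, 4, 2, 8, 8, dtype=torch.bfloat16)
latents_mean = torch.randn(4).view(1, -1, 1, 1, 1)
latents_std = torch.rand(4).view(1, -1, 1, 1, 1) + 0.5

# Normalize in float32, then cast back to the original dtype, as `.to(latents)` does.
normalized = ((latents.float() - latents_mean) / latents_std).to(latents)
assert normalized.dtype == latents.dtype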
finetrainers/models/wan/base_specification.py CHANGED
@@ -39,7 +39,7 @@ class WanLatentEncodeProcessor(ProcessorMixin):
     def __init__(self, output_names: List[str]):
         super().__init__()
         self.output_names = output_names
-        assert len(self.output_names) == 1
+        assert len(self.output_names) == 3
 
     def forward(
         self,
@@ -72,7 +72,10 @@ class WanLatentEncodeProcessor(ProcessorMixin):
         moments = vae._encode(video)
         latents = moments.to(dtype=dtype)
 
-        return {self.output_names[0]: latents}
+        latents_mean = torch.tensor(vae.config.latents_mean)
+        latents_std = 1.0 / torch.tensor(vae.config.latents_std)
+
+        return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
 
 
 class WanModelSpecification(ModelSpecification):
@@ -108,7 +111,7 @@ class WanModelSpecification(ModelSpecification):
         if condition_model_processors is None:
             condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
         if latent_model_processors is None:
-            latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
+            latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
 
         self.condition_model_processors = condition_model_processors
         self.latent_model_processors = latent_model_processors
@@ -266,7 +269,10 @@ class WanModelSpecification(ModelSpecification):
             "image": image,
             "video": video,
             "generator": generator,
-            "compute_posterior": compute_posterior,
+            # We must force this to False because the latent normalization should be done before
+            # the posterior is computed. The VAE does not handle this any more:
+            # https://github.com/huggingface/diffusers/pull/10998
+            "compute_posterior": False,
             **kwargs,
         }
         input_keys = set(conditions.keys())
@@ -284,20 +290,29 @@ class WanModelSpecification(ModelSpecification):
         compute_posterior: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, ...]:
+        compute_posterior = False  # See explanation in prepare_latents
         if compute_posterior:
             latents = latent_model_conditions.pop("latents")
         else:
-            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = latent_model_conditions.pop("latents")
+            latents_mean = latent_model_conditions.pop("latents_mean")
+            latents_std = latent_model_conditions.pop("latents_std")
+
+            mu, logvar = torch.chunk(latents, 2, dim=1)
+            mu = self._normalize_latents(mu, latents_mean, latents_std)
+            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+            latents = torch.cat([mu, logvar], dim=1)
+
+            posterior = DiagonalGaussianDistribution(latents)
             latents = posterior.sample(generator=generator)
             del posterior
 
         noise = torch.zeros_like(latents).normal_(generator=generator)
         noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+        timesteps = (sigmas.flatten() * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
-        timesteps = (sigmas.flatten() * 1000.0).long()
-
         pred = transformer(
             **latent_model_conditions,
             **condition_model_conditions,
@@ -367,3 +382,12 @@ class WanModelSpecification(ModelSpecification):
             transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
         if scheduler is not None:
             scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+    ) -> torch.Tensor:
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+        return latents
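The substantive Wan change: the processor now returns the raw VAE moments together with the per-channel `latents_mean` and a pre-inverted `latents_std` (note the `1.0 / ...`, which turns the later division into a multiplication), and the training step normalizes both halves of the moments before building the posterior, since diffusers no longer does this inside the VAE (https://github.com/huggingface/diffusers/pull/10998). A self-contained sketch of that normalize-then-sample flow (names and shapes are illustrative, not the Finetrainers API):

import torch

def normalize(x, mean, inv_std):
    # Mirror _normalize_latents: broadcast per-channel stats, compute in fp32.
    mean = mean.view(1, -1, 1, 1, 1)
    inv_std = inv_std.view(1, -1, 1, 1, 1)
    return ((x.float() - mean) * inv_std).to(x)

moments = torch.randn(1, 32, 4, 8, 8)   # [mu, logvar] stacked along channels
mean, inv_std = torch.zeros(16), torch.ones(16)

mu, logvar = torch.chunk(moments, 2, dim=1)
mu, logvar = normalize(mu, mean, inv_std), normalize(logvar, mean, inv_std)

# Sample the diagonal Gaussian that the normalized moments parameterize.
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
latents = mu + std * torch.randn_like(mu)
assert latents.shape[1] == 16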
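One detail worth flagging: `forward` hard-codes `compute_posterior = False`, so the `if compute_posterior:` branch above it is now dead code kept only to preserve the original structure. Separately, because the processor stores `1.0 / latents_std`, `_normalize_latents` multiplies rather than divides; a two-line check of the equivalence:

import torch

latents_std = torch.tensor([2.0, 4.0])
inv_std = 1.0 / latents_std               # what the processor stores
x = torch.tensor([8.0, 8.0])
assert torch.allclose(x / latents_std, x * inv_std)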
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -147,8 +147,11 @@ class SFTTrainer:
 
         # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
         # parameters to be of the same dtype.
-        if self.args.training_type == TrainingType.LORA and not parallel_backend.data_sharding_enabled:
-            cast_training_params([self.transformer], dtype=torch.float32)
+        if parallel_backend.data_sharding_enabled:
+            self.transformer.to(dtype=self.args.transformer_dtype)
+        else:
+            if self.args.training_type == TrainingType.LORA:
+                cast_training_params([self.transformer], dtype=torch.float32)
 
     def _prepare_for_training(self) -> None:
         # 1. Apply parallelism
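The trainer change makes the dtype policy explicit: with data sharding (e.g. FSDP) every parameter must share one dtype, so the whole transformer is cast to `transformer_dtype`; without it, only the trainable (LoRA) parameters are promoted to float32. A toy sketch of the two branches (the module and flag are stand-ins, not the real trainer state):

import torch
import torch.nn as nn

model = nn.Linear(8, 8).to(dtype=torch.bfloat16)  # stand-in for the transformer
data_sharding_enabled = True                      # assumption for the example

if data_sharding_enabled:
    # FSDP flattens parameters, so they must all share one dtype.
    model.to(dtype=torch.bfloat16)
else:
    # Keep trainable (e.g. LoRA) params in fp32 for stable optimizer updates,
    # similar in spirit to diffusers' cast_training_params.
    for p in model.parameters():
        if p.requires_grad:
            p.data = p.data.float()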