upgrade Finetrainers
finetrainers/models/cogvideox/base_specification.py
CHANGED
@@ -299,7 +299,7 @@ class CogVideoXModelSpecification(ModelSpecification):
             latents = posterior.sample(generator=generator)
             del posterior
 
-        if not self.vae_config.invert_scale_latents:
+        if not getattr(self.vae_config, "invert_scale_latents", False):
             latents = latents * self.vae_config.scaling_factor
 
         if patch_size_t is not None:
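A note on the fix above: the `getattr` lookup with a `False` default presumably guards against VAE configs that do not carry the `invert_scale_latents` key, keeping the old scaling behavior instead of raising `AttributeError`. A minimal standalone sketch of the difference (the config objects here are hypothetical stand-ins, not the real VAE config class):

from types import SimpleNamespace

# Hypothetical stand-ins: an older config without the flag, a newer one with it.
old_config = SimpleNamespace(scaling_factor=0.7)
new_config = SimpleNamespace(scaling_factor=0.7, invert_scale_latents=True)

for cfg in (old_config, new_config):
    # Direct access (cfg.invert_scale_latents) would raise AttributeError on
    # old_config; getattr with a default degrades gracefully instead.
    if not getattr(cfg, "invert_scale_latents", False):
        print("scaling latents by", cfg.scaling_factor)
    else:
        print("skipping scaling; the VAE already inverts the scale")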
finetrainers/models/ltx_video/base_specification.py
CHANGED
@@ -336,8 +336,8 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents = self._pack_latents(latents, patch_size, patch_size_t)
         noise = self._pack_latents(noise, patch_size, patch_size_t)
         noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
-
         sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
+        timesteps = (sigmas * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
@@ -352,7 +352,6 @@ class LTXVideoModelSpecification(ModelSpecification):
             vae_spatial_compression_ratio,
             vae_spatial_compression_ratio,
         ]
-        timesteps = (sigmas * 1000.0).long()
 
         pred = transformer(
             **latent_model_conditions,
@@ -444,9 +443,9 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
         # Normalize latents across the channel dimension [B, C, F, H, W]
-        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device)
-        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device)
-        latents = (latents - latents_mean) * scaling_factor / latents_std
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
         return latents
 
     @staticmethod
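The rewritten `_normalize_latents` now does its arithmetic in float32 and only casts back to the input dtype at the end, which avoids precision loss when training in bf16/fp16. A standalone sketch of the pattern (shapes and statistics are illustrative, not the LTX-Video VAE's):

import torch

# Illustrative shapes/statistics only, not the LTX-Video VAE's.
latents = torch.randn(1, 4, 2, 8, 8, dtype=torch.bfloat16)
latents_mean = torch.randn(4)
latents_std = torch.rand(4) + 0.5

mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)

# Do the arithmetic in float32, then cast back to the input dtype; doing the
# subtraction/division directly in bf16 would round away low-order bits.
normalized = ((latents.float() - mean) / std).to(latents)
assert normalized.dtype == torch.bfloat16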
finetrainers/models/wan/base_specification.py
CHANGED
@@ -39,7 +39,7 @@ class WanLatentEncodeProcessor(ProcessorMixin):
     def __init__(self, output_names: List[str]):
         super().__init__()
         self.output_names = output_names
-        assert len(self.output_names) == 1
+        assert len(self.output_names) == 3
 
     def forward(
         self,
@@ -72,7 +72,10 @@ class WanLatentEncodeProcessor(ProcessorMixin):
         moments = vae._encode(video)
         latents = moments.to(dtype=dtype)
 
-        return {self.output_names[0]: latents}
+        latents_mean = torch.tensor(vae.config.latents_mean)
+        latents_std = 1.0 / torch.tensor(vae.config.latents_std)
+
+        return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
 
 
 class WanModelSpecification(ModelSpecification):
@@ -108,7 +111,7 @@ class WanModelSpecification(ModelSpecification):
         if condition_model_processors is None:
             condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
         if latent_model_processors is None:
-            latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
+            latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
 
         self.condition_model_processors = condition_model_processors
         self.latent_model_processors = latent_model_processors
@@ -266,7 +269,10 @@ class WanModelSpecification(ModelSpecification):
             "image": image,
             "video": video,
             "generator": generator,
-            "compute_posterior": compute_posterior,
+            # We must force this to False because the latent normalization should be done before
+            # the posterior is computed. The VAE does not handle this any more:
+            # https://github.com/huggingface/diffusers/pull/10998
+            "compute_posterior": False,
             **kwargs,
         }
         input_keys = set(conditions.keys())
@@ -284,20 +290,29 @@ class WanModelSpecification(ModelSpecification):
         compute_posterior: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, ...]:
+        compute_posterior = False  # See explanation in prepare_latents
         if compute_posterior:
             latents = latent_model_conditions.pop("latents")
         else:
-            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = latent_model_conditions.pop("latents")
+            latents_mean = latent_model_conditions.pop("latents_mean")
+            latents_std = latent_model_conditions.pop("latents_std")
+
+            mu, logvar = torch.chunk(latents, 2, dim=1)
+            mu = self._normalize_latents(mu, latents_mean, latents_std)
+            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+            latents = torch.cat([mu, logvar], dim=1)
+
+            posterior = DiagonalGaussianDistribution(latents)
             latents = posterior.sample(generator=generator)
             del posterior
 
         noise = torch.zeros_like(latents).normal_(generator=generator)
         noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+        timesteps = (sigmas.flatten() * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
-        timesteps = (sigmas.flatten() * 1000.0).long()
-
         pred = transformer(
             **latent_model_conditions,
             **condition_model_conditions,
@@ -367,3 +382,12 @@ class WanModelSpecification(ModelSpecification):
         transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
         if scheduler is not None:
             scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+    ) -> torch.Tensor:
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+        return latents
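Context for the Wan changes: after https://github.com/huggingface/diffusers/pull/10998 the VAE no longer normalizes latents itself, so the trainer now pops the raw encoder moments plus the config statistics, normalizes the mean and log-variance halves, and only then builds the posterior. A standalone sketch of that flow (illustrative shapes; the import path assumes a recent diffusers version):

import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

C = 16  # illustrative latent channel count
moments = torch.randn(1, 2 * C, 4, 8, 8)  # raw encoder output: [mu | logvar]
latents_mean = torch.randn(C)
latents_std_inv = torch.rand(C) + 0.5     # the diff stores 1 / latents_std

def normalize(x: torch.Tensor) -> torch.Tensor:
    mean = latents_mean.view(1, -1, 1, 1, 1)
    std_inv = latents_std_inv.view(1, -1, 1, 1, 1)
    return ((x.float() - mean) * std_inv).to(x)

# Normalize the mean and log-variance halves separately (mirroring the diff),
# rebuild the moments tensor, and sample from the resulting posterior.
mu, logvar = torch.chunk(moments, 2, dim=1)
moments = torch.cat([normalize(mu), normalize(logvar)], dim=1)
posterior = DiagonalGaussianDistribution(moments)
latents = posterior.sample()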
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -147,8 +147,11 @@ class SFTTrainer:
 
         # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
         # parameters to be of the same dtype.
-        if self.args.training_type == TrainingType.LORA and not parallel_backend.data_sharding_enabled:
-            cast_training_params([self.transformer], dtype=torch.float32)
+        if parallel_backend.data_sharding_enabled:
+            self.transformer.to(dtype=self.args.transformer_dtype)
+        else:
+            if self.args.training_type == TrainingType.LORA:
+                cast_training_params([self.transformer], dtype=torch.float32)
 
     def _prepare_for_training(self) -> None:
         # 1. Apply parallelism