Spaces:
Sleeping
Sleeping
import logging | |
import math | |
from typing import Dict, List, Optional, Tuple | |
import PIL | |
import PIL.Image | |
import torch | |
from diffusers import DiffusionPipeline | |
from rewards import clip_img_transform | |
from rewards.base_reward import BaseRewardLoss | |
class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses.

    Each call to :meth:`train` repeatedly renders an image from the latents
    via ``model.apply``, scores it with every reward loss, and (unless
    disabled) back-propagates the weighted total into the latents and steps
    the caller-supplied optimizer. The image with the lowest weighted reward
    loss is returned.
    """

    def __init__(
        self,
        reward_losses: "List[BaseRewardLoss]",
        model: "DiffusionPipeline",
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """Store the training configuration.

        Args:
            reward_losses: Callables invoked as ``loss(image, prompt)``; each
                must expose ``name`` and ``weighting`` attributes.
            model: Object whose ``apply(latents, prompt, generator=...,
                num_inference_steps=...)`` returns the generated image tensor.
            n_iters: Number of iterations run by :meth:`train`.
            n_inference_steps: Forwarded to ``model.apply`` as
                ``num_inference_steps``.
            seed: Seed used to re-create the generator on every iteration.
            no_optim: If true, :meth:`train` generates a single image and
                returns it without computing any reward.
            regularize: If true, add a latent-norm penalty to the loss.
            regularization_weight: Multiplier applied to that penalty.
            grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
            log_metrics: If true, log the per-iteration metrics.
            save_all_images: If true, save every iteration's image into the
                ``save_dir`` given to :meth:`train`.
            imageselect: If true, draw fresh random latents every iteration
                and skip the gradient step (pure best-of-N selection).
            device: Device for the random generator and the sampled latents.
        """
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        self.preprocess_fn = clip_img_transform(224)

    @staticmethod
    def _to_pil(image: torch.Tensor) -> "PIL.Image.Image":
        """Convert the first image of a batched image tensor to a PIL image."""
        # Presumably (batch, channels, H, W) -> (batch, H, W, channels), the
        # layout numpy_to_pil expects -- TODO confirm against the pipeline.
        image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return DiffusionPipeline.numpy_to_pil(image_numpy)[0]

    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
    ) -> "Tuple[PIL.Image.Image, Optional[Dict[str, float]], Optional[Dict[str, float]]]":
        """Optimize ``latents`` for ``prompt`` and return the best image.

        Args:
            latents: Latent tensor being optimized; its gradient is clipped
                here and it is expected to be managed by ``optimizer``.
            prompt: Text prompt passed to the model and to every reward loss.
            optimizer: Optimizer stepped after each backward pass.
            save_dir: Directory for per-iteration images; required when
                ``save_all_images`` is enabled (and ``no_optim`` is not).

        Returns:
            A tuple ``(best_image, initial_rewards, best_rewards)``. The
            reward dicts map each reward name (plus ``"total"`` and, when
            regularizing, ``"norm"``) to a float. Both dicts are ``None``
            in ``no_optim`` mode, where no reward is evaluated.

        Raises:
            ValueError: If ``n_iters`` is smaller than 1, or if images
                should be saved but ``save_dir`` is ``None``.
            RuntimeError: If no iteration produced a finite reward loss.
        """
        # Fail fast on configurations that would otherwise only blow up
        # after one or more expensive generation passes.
        if self.n_iters < 1:
            raise ValueError(
                f"n_iters must be at least 1 to produce an image, got {self.n_iters}."
            )
        if self.save_all_images and not self.no_optim and save_dir is None:
            raise ValueError("save_dir must be provided when save_all_images is enabled.")

        logging.info("Optimizing latents for prompt '%s'.", prompt)
        best_loss = torch.inf
        best_image = None
        initial_rewards = None
        best_rewards = None
        latent_dim = math.prod(latents.shape[1:])

        for iteration in range(self.n_iters):
            log_parts = []
            rewards = {}
            optimizer.zero_grad()
            # Re-seed every iteration so the model sees identical generator
            # state each time; create it on the configured device instead of
            # a hard-coded "cuda".
            generator = torch.Generator(self.device).manual_seed(self.seed)
            if self.imageselect:
                source_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
            else:
                source_latents = latents
            image = self.model.apply(
                source_latents,
                prompt,
                generator=generator,
                num_inference_steps=self.n_inference_steps,
            )
            if self.no_optim:
                best_image = image
                break

            total_loss = 0
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                # Read the scalar once; every .item() is a device round trip.
                loss_value = loss.item()
                log_parts.append(f"{reward_loss.name}: {loss_value:.4f}")
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss_value
            # Best-image selection uses the reward total *before* the
            # regularization term is added.
            total_reward_loss = total_loss.item()
            rewards["total"] = total_reward_loss
            log_parts.append(f"Total: {total_reward_loss:.4f}")

            if self.regularize:
                # Cast the norm to fp32 before squaring/scaling to avoid
                # overflow.
                # NOTE(review): in imageselect mode this still measures the
                # caller's `latents`, not the freshly sampled ones -- confirm
                # that this is intended.
                latent_norm = torch.linalg.vector_norm(latents).to(torch.float32)
                log_norm = torch.log(latent_norm)
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                norm_value = latent_norm.item()
                log_parts.append(f"Latent norm: {norm_value}")
                rewards["norm"] = norm_value
                total_loss += regularization.to(total_loss.dtype)

            if self.log_metrics:
                logging.info("Iteration %s: %s", iteration, ", ".join(log_parts))
            if initial_rewards is None:
                initial_rewards = rewards
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                # Detach so the autograd graph of this iteration is not kept
                # alive for the rest of training.
                best_image = image.detach()
                best_rewards = rewards

            # The last iteration only evaluates; imageselect never optimizes.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()
            if self.save_all_images:
                self._to_pil(image).save(f"{save_dir}/{iteration}.png")

        if best_image is None:
            # Only reachable when every reward total was NaN or infinite.
            raise RuntimeError(
                "No iteration produced a finite reward loss; cannot select a best image."
            )
        return self._to_pil(best_image), initial_rewards, best_rewards