import logging
import math
from typing import Dict, List, Optional, Tuple

import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline

from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses."""

    def __init__(
        self,
        reward_losses: List[BaseRewardLoss],
        model: DiffusionPipeline,
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        self.preprocess_fn = clip_img_transform(224)
    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
    ) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
        """Optimize the latents for the given prompt and return the best image
        together with the initial and best reward values."""
        logging.info(f"Optimizing latents for prompt '{prompt}'.")
        best_loss = torch.inf
        best_image = None
        initial_rewards = None
        best_rewards = None
        latent_dim = math.prod(latents.shape[1:])
        for iteration in range(self.n_iters):
            to_log = ""
            rewards = {}
            optimizer.zero_grad()
            generator = torch.Generator("cuda").manual_seed(self.seed)
            if self.imageselect:
                # Image-selection baseline: sample fresh latents every iteration
                # instead of optimizing the current ones.
                new_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
                image = self.model.apply(
                    new_latents,
                    prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            else:
                image = self.model.apply(
                    latents,
                    prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            if self.no_optim:
                best_image = image
                break
            total_loss = 0
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                to_log += f"{reward_loss.name}: {loss.item():.4f}, "
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss.item()
            rewards["total"] = total_loss.item()
            to_log += f"Total: {total_loss.item():.4f}"
            total_reward_loss = total_loss.item()
            if self.regularize:
                # Keep the latents close to the typical set of a standard Gaussian:
                # 0.5 * ||z||^2 - (d - 1) * log ||z|| is, up to a constant, the
                # negative log-density of the norm of a d-dimensional standard normal.
                # Compute in fp32 to avoid overflow.
                latent_norm = torch.linalg.vector_norm(latents).to(torch.float32)
                log_norm = torch.log(latent_norm)
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                to_log += f", Latent norm: {latent_norm.item()}"
                rewards["norm"] = latent_norm.item()
                total_loss += regularization.to(total_loss.dtype)
            if self.log_metrics:
                logging.info(f"Iteration {iteration}: {to_log}")
            if initial_rewards is None:
                initial_rewards = rewards
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                best_image = image
                best_rewards = rewards
            # Skip the gradient update on the final iteration and in image-selection mode.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()
            if self.save_all_images:
                image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
                image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
                image_pil.save(f"{save_dir}/{iteration}.png")
        image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
        return image_pil, initial_rewards, best_rewards
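

# Illustrative usage sketch (not part of the original module). It assumes that
# `model` is the project's pipeline wrapper exposing .apply(latents, prompt,
# generator=..., num_inference_steps=...) as called above, and that
# `reward_losses` holds configured BaseRewardLoss instances; both are
# placeholders here, and the latent shape and hyperparameters are illustrative.
#
# if __name__ == "__main__":
#     model = ...  # assumed: pipeline wrapper with an .apply(...) method
#     reward_losses = [...]  # assumed: list of BaseRewardLoss instances
#     latents = torch.randn(1, 4, 128, 128, device="cuda", requires_grad=True)
#     optimizer = torch.optim.SGD([latents], lr=5.0)
#     trainer = LatentNoiseTrainer(
#         reward_losses=reward_losses,
#         model=model,
#         n_iters=50,
#         n_inference_steps=1,
#         seed=0,
#     )
#     best_image, initial_rewards, best_rewards = trainer.train(
#         latents, "a photo of an astronaut riding a horse", optimizer
#     )
#     best_image.save("best.png")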