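"""Trainer for the stage-2 GPT that models sequences of slot latents.

The slot latents come from a frozen stage-1 autoencoder (``ae_model``); they can
optionally be pre-encoded once and cached to per-rank memory-mapped files to
speed up training. Training runs under Accelerate with optional EMA weights,
periodic sampling of a fixed set of evaluation classes, and FID/IS evaluation
against pre-computed statistics.
"""
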
import os, torch
import os.path as osp
import shutil
import numpy as np
import copy
import torch.nn as nn
from tqdm.auto import tqdm
from accelerate import Accelerator
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, DistributedSampler
from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
from accelerate.utils import DistributedDataParallelKwargs
from semanticist.stage2.gpt import GPT_models
from semanticist.stage2.generate import generate
from pathlib import Path
import time
from semanticist.engine.trainer_utils import (
    instantiate_from_config, concat_all_gather,
    save_img_batch, get_fid_stats,
    EMAModel, create_scheduler, load_state_dict, load_safetensors,
    setup_result_folders, create_optimizer,
    CacheDataLoader
)


class GPTTrainer(nn.Module):
    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850],  # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=8,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision='bf16',
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        ae_cfg=1.0,
        cfg=6.0,
        cfg_schedule="linear",
        temperature=1.0,
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache'
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision=precision,
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.ae_model = instantiate_from_config(ae_model)
        ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith((".safetensors", ".pt", ".pth", ".pkl"))
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)
        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes
        self.batch_size = batch_size

        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = create_optimizer(self.gpt_model, weight_decay=0.05, learning_rate=lr)
            if warmup_epochs is not None:
                warmup_steps = warmup_epochs * len(self.train_dl)
            self.g_sched = create_scheduler(
                self.g_optim,
                num_epoch,
                len(self.train_dl),
                lr_min,
                warmup_steps,
                warmup_lr_init,
                decay_steps,
                cosine_lr
            )
            self.accelerator.register_for_checkpointing(self.g_sched)
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(self.gpt_model, self.g_optim, self.g_sched)
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        self.steps = 0
        self.loaded_steps = -1

        if compile:
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            _model = self.accelerator.unwrap_model(self.gpt_model)
            _model = torch.compile(_model, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.sample_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.eval_classes = eval_classes
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.cfg_schedule = cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
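
        # train_num_slots caps how many slots are kept per image during training;
        # num_slots_to_gen is how many the GPT generates at evaluation time. Any
        # remaining slot positions are filled with the autoencoder's null condition
        # before decoding (see evaluate()).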
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots

        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None
        self.fid_stats = fid_stats

        # Setup result folders
        self.result_folder = result_folder
        self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)

        # Setup cache
        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.cache_loader = None

    @property
    def device(self):
        return self.accelerator.device

    def _load_checkpoint(self, ckpt_path=None, model=None):
        """Load a checkpoint into `model` (defaults to the GPT model)."""
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            # Checkpoint folders are named ".../step<N>"; recover the step count.
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        load_safetensors(model_path, model)
        else:
            if ckpt_path.endswith(".safetensors"):
                load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                load_state_dict(state_dict, model)

        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _build_cache(self):
        """Build cache for slots and targets."""
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes

        # Clean up any existing cache files first
        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
        if slots_file.exists():
            os.remove(slots_file)
        if targets_file.exists():
            os.remove(targets_file)

        dataset_size = len(self.train_dl.dataset)
        shard_size = dataset_size // world_size

        # Detect number of augmentations from first batch
        with torch.no_grad():
            sample_batch = next(iter(self.train_dl))
            img, _ = sample_batch
            num_augs = img.shape[1] if len(img.shape) == 5 else 1

        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)
        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"

        # Create memory-mapped files
        slots_mmap = np.memmap(
            slots_file,
            dtype='float32',
            mode='w+',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        targets_mmap = np.memmap(
            targets_file,
            dtype='int64',
            mode='w+',
            shape=(shard_size * num_augs,)
        )

        # Cache data
        with torch.no_grad():
            for i, batch in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process
            )):
                imgs, targets = batch
                if len(imgs.shape) == 5:  # [B, num_augs, C, H, W]
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.view(-1, C, H, W)  # [B*num_augs, C, H, W]
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)  # [B*num_augs]

                # Split imgs into n chunks
                num_splits = num_augs
                split_size = imgs.shape[0] // num_splits
                imgs_splits = torch.split(imgs, split_size)
                targets_splits = torch.split(targets, split_size)

                start_idx = i * self.train_dl.batch_size * num_augs
                for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
                    split_start = start_idx + (split_idx * split_size)
                    split_end = split_start + img_split.shape[0]

                    # Write directly to mmap files
                    slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Close the mmap files
        del slots_mmap
        del targets_mmap

        # Reopen in read mode
        self.cached_latents = np.memmap(
            slots_file,
            dtype='float32',
            mode='r',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        self.cached_targets = np.memmap(
            targets_file,
            dtype='int64',
            mode='r',
            shape=(shard_size * num_augs,)
        )

        # Store the number of augmentations for the cache loader
        self.num_augs = num_augs
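
    # Cache layout (per rank): slots are stored in a float32 memmap of shape
    # [shard_size * num_augs, train_num_slots, slot_dim] and targets in an int64
    # memmap of shape [shard_size * num_augs], both under cache_dir; the files
    # are deleted again in __del__.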

    def _setup_cache(self):
        """Setup cache if enabled."""
        self._build_cache()
        self.accelerator.wait_for_everyone()

        # Initialize cache loader if cache exists
        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=42 + self.accelerator.process_index
            )

    def __del__(self):
        """Cleanup cache files."""
        if self.enable_cache_latents:
            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes

            # Clean up slots cache
            slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
            if slots_file.exists():
                os.remove(slots_file)

            # Clean up targets cache
            targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
            if targets_file.exists():
                os.remove(targets_file)

    def _train_step(self, slots, targets=None):
        """Execute single training step."""
        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()

        # Update EMA model if enabled
        if self.enable_ema:
            self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))

        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached data."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch using raw data."""
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)
            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Handle periodic operations and logging."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])

        if self.steps % self.save_every == 0:
            self.save()

        if (self.steps % self.sample_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Save configuration file."""
        if config is not None and self.accelerator.is_main_process:
            import shutil
            from omegaconf import OmegaConf
            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Check if epoch should be skipped due to loaded checkpoint."""
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True

        if self.steps < self.loaded_steps:
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop."""
        # Initial setup
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters//1e6}M")

        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        # Handle test-only mode
        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        # Setup cache if enabled
        if self.enable_cache_latents:
            self._setup_cache()

        # Training loop
        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue

            self.gpt_model.train()
            logger = MetricLogger(delimiter=" ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

            # Choose training path based on cache availability
            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

            # Synchronize and log epoch stats
            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        # Finish training
        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    def evaluate(self, use_ema=True):
        self.gpt_model.eval()
        unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)

        # switch to ema params, only when eval_fid is True
        # if test_only, we directly load the ema dict and skip here
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                    if "nested_sampler" in name:
                        continue
                    ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                model_without_ddp.load_state_dict(ema_state_dict)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        if not self.test_only:
            classes = torch.tensor(self.eval_classes, device=self.device)
            with self.accelerator.autocast():
                slots = generate(unwraped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg, cfg_schedule=self.cfg_schedule, temperature=self.temperature)
                if self.num_slots_to_gen < self.num_slots:
                    null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                    null_slots = null_slots[:, self.num_slots_to_gen:, :]
                    slots = torch.cat([slots, null_slots], dim=1)
                imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg)  # targets are not used for now

            imgs = concat_all_gather(imgs)
            if self.accelerator.num_processes > 16:
                imgs = imgs[:16*len(self.eval_classes)]
            imgs = imgs.detach().cpu()
            grid = make_grid(
                imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
            )
            if self.accelerator.is_main_process:
                save_image(
                    grid,
                    os.path.join(
                        self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
                    ),
                )

        if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
            # Create output directory (only on main process)
            save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
            if self.accelerator.is_main_process:
                os.makedirs(save_folder, exist_ok=True)

            # Setup for distributed generation
            world_size = self.accelerator.num_processes
            local_rank = self.accelerator.process_index
            batch_size = self.test_bs

            # Create balanced class distribution
            num_classes = self.num_classes
            images_per_class = self.num_test_images // num_classes
            class_labels = np.repeat(np.arange(num_classes), images_per_class)

            # Shuffle the class labels to ensure random ordering
            np.random.shuffle(class_labels)

            total_images = len(class_labels)
            padding_size = world_size * batch_size - (total_images % (world_size * batch_size))
            class_labels = np.pad(class_labels, (0, padding_size), 'constant')
            padded_total_images = len(class_labels)

            # Distribute workload across GPUs
            images_per_gpu = padded_total_images // world_size
            start_idx = local_rank * images_per_gpu
            end_idx = min(start_idx + images_per_gpu, padded_total_images)
            local_class_labels = class_labels[start_idx:end_idx]
            local_num_steps = len(local_class_labels) // batch_size

            if self.accelerator.is_main_process:
                print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")

            used_time = 0
            gen_img_cnt = 0

            for i in range(local_num_steps):
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Generation step {i}/{local_num_steps}")

                # Get and pad labels for current batch
                batch_start = i * batch_size
                batch_end = batch_start + batch_size
                labels = local_class_labels[batch_start:batch_end]

                # Convert to tensors and track real vs padding
                labels = torch.tensor(labels, device=self.device)

                # Generate images
                self.accelerator.wait_for_everyone()
                start_time = time.time()
                with torch.no_grad():
                    with self.accelerator.autocast():
                        slots = generate(unwraped_gpt_model, labels, self.num_slots_to_gen,
                                         cfg_scale=self.cfg,
                                         cfg_schedule=self.cfg_schedule,
                                         temperature=self.temperature)
                        if self.num_slots_to_gen < self.num_slots:
                            null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                            null_slots = null_slots[:, self.num_slots_to_gen:, :]
                            slots = torch.cat([slots, null_slots], dim=1)
                        imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)

                samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)

                # Update timing stats
                used_time += time.time() - start_time
                gen_img_cnt += samples_in_batch
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")

                gathered_imgs = concat_all_gather(imgs)
                gathered_imgs = gathered_imgs[:samples_in_batch]

                # Save images (only on main process)
                if self.accelerator.is_main_process:
                    real_imgs = gathered_imgs.detach().cpu()
                    save_paths = [
                        os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
                        for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
                    ]
                    save_img_batch(real_imgs, save_paths)

            # Calculate metrics (only on main process)
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                generated_files = len(os.listdir(save_folder))
                print(f"Generated {generated_files} images out of {total_images} expected")

                metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
                fid = metrics_dict["frechet_inception_distance"]
                inception_score = metrics_dict["inception_score_mean"]

                metric_prefix = "fid_ema" if use_ema else "fid"
                isc_prefix = "isc_ema" if use_ema else "isc"
                self.accelerator.log({
                    metric_prefix: fid,
                    isc_prefix: inception_score,
                    "gpt_cfg": self.cfg,
                    "ae_cfg": self.ae_cfg,
                    "cfg_schedule": self.cfg_schedule,
                    "temperature": self.temperature,
                    "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
                }, step=self.steps)

                # Print comprehensive CFG information
                cfg_info = (
                    f"{'EMA ' if use_ema else ''}CFG params: "
                    f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, "
                    f"cfg_schedule={self.cfg_schedule}, "
                    f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
                    f"temperature={self.temperature}"
                )
                print(cfg_info)
                print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")

                # Cleanup
                shutil.rmtree(save_folder)

        # back to no ema
        if use_ema:
            if self.accelerator.is_main_process:
                print("Switch back from ema")
            model_without_ddp.load_state_dict(model_state_dict)

        self.gpt_model.train()
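

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes an
    # OmegaConf YAML whose `trainer.target` points at GPTTrainer and whose
    # `trainer.params` mirror the constructor arguments above (ae_model,
    # gpt_model, dataset, ...). The config path below is hypothetical.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/gpt_trainer.yaml")  # hypothetical config path
    trainer = instantiate_from_config(cfg.trainer)
    trainer.train(config=cfg)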