import os
import importlib
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import torch
import torch_fidelity
from torch.optim import AdamW

from semanticist.utils.lr_scheduler import build_scheduler


def get_obj_from_str(string, reload=False):
    """Get an object from a dotted import path such as ``package.module.Class``."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate an object from a config dict with a ``target`` path and optional ``params``."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def is_dist_avail_and_initialized():
    """Check if distributed training is available and initialized."""
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return True


def is_main_process():
    """Check if the current process is the main (rank-0) process."""
    return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0


def concat_all_gather(tensor):
    """
    Perform an all_gather operation on the provided tensor.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def requires_grad(model, flag=True):
    """Set the requires_grad flag for all model parameters."""
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    """Save a single image tensor (CHW, values in [0, 1]) to disk."""
    img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]  # RGB -> BGR for OpenCV
    cv2.imwrite(save_path, img)


def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once

    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [
            pool.submit(cv2.imwrite, path, img)
            for path, img in zip(save_paths, imgs)
        ]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred


def get_fid_stats(real_dir, rec_dir, fid_stats):
    """Calculate FID and Inception Score statistics between real and reconstructed images."""
    stats = torch_fidelity.calculate_metrics(
        input1=rec_dir,
        input2=real_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats


def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps,
                     warmup_lr_init, decay_steps, cosine_lr):
    """Create a learning rate scheduler."""
    scheduler = build_scheduler(
        optimizer,
        num_epoch,
        steps_per_epoch,
        lr_min,
        warmup_steps,
        warmup_lr_init,
        decay_steps,
        cosine_lr,
    )
    return scheduler


def load_state_dict(state_dict, model):
    """Helper to load a state dict with proper prefix handling."""
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Remove the '_orig_mod.' prefix added by torch.compile, if present
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    if is_main_process():
        print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")


def load_safetensors(path, model):
    """Helper to load a safetensors checkpoint."""
    from safetensors.torch import safe_open
    with safe_open(path, framework="pt", device="cpu") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}
    load_state_dict(state_dict, model)
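

# Usage sketch (illustrative only, never called by this module): how the checkpoint
# helpers above might be used together. The paths and the passed-in `model` are
# hypothetical placeholders, not files or objects defined in this repository.
def _example_load_checkpoint(model):
    # Plain PyTorch checkpoint: torch.load returns a dict that load_state_dict()
    # unwraps (optional 'state_dict' key, '_orig_mod.' prefixes) before loading.
    ckpt = torch.load("checkpoints/model.pt", map_location="cpu")  # hypothetical path
    load_state_dict(ckpt, model)

    # Safetensors checkpoint: load_safetensors() reads the tensors on CPU and then
    # reuses the same prefix-handling logic.
    load_safetensors("checkpoints/model.safetensors", model)  # hypothetical path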


def setup_result_folders(result_folder):
    """Set up result folders for saving models and images."""
    model_saved_dir = os.path.join(result_folder, "models")
    os.makedirs(model_saved_dir, exist_ok=True)

    image_saved_dir = os.path.join(result_folder, "images")
    os.makedirs(image_saved_dir, exist_ok=True)

    return model_saved_dir, image_saved_dir


def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
    """Create an AdamW optimizer that applies weight decay only to parameters with >= 2 dimensions."""
    # start with all of the candidate parameters
    param_dict = {pn: p for pn, p in model.named_parameters()}
    # filter out those that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # create optim groups. Any parameter that is 2D or higher will be weight decayed, otherwise not,
    # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    if is_main_process():
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
    return optimizer


class EMAModel:
    """Model Exponential Moving Average."""

    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    # ema = decay * ema + (1 - decay) * param
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )


class PaddedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that pads a dataset to ensure even distribution across processes."""

    def __init__(self, dataset, padding_size):
        self.dataset = dataset
        self.padding_size = padding_size

    def __len__(self):
        return len(self.dataset) + self.padding_size

    def __getitem__(self, idx):
        if idx < len(self.dataset):
            return self.dataset[idx]
        # Padding indices map back to the first sample
        return self.dataset[0]
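

# Usage sketch (illustrative only, never called by this module): a minimal EMA update
# loop and the padding computation PaddedDataset is intended for. The toy linear model,
# learning rate, dataset size, and world size are made-up values, not configuration
# taken from this repository.
def _example_ema_and_padding():
    import torch.nn as nn

    model = nn.Linear(8, 8)  # stand-in for a real model
    ema = EMAModel(model, device="cpu", decay=0.999)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    for _ in range(3):  # a few dummy training steps
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update(model)  # shadow weights track the live weights after each step

    # Pad a dataset so its length divides evenly across processes.
    dataset = torch.utils.data.TensorDataset(torch.randn(10, 8))
    world_size = 4  # hypothetical number of processes
    padding = (world_size - len(dataset) % world_size) % world_size
    padded = PaddedDataset(dataset, padding)
    assert len(padded) % world_size == 0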


class CacheDataLoader:
    """DataLoader-like interface for cached data with epoch-based shuffling."""

    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # Original dataset size (before augmentations)
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return number of batches based on original dataset size."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield randomly sampled batches for the current epoch."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)

        # Randomly sample indices from the entire augmented dataset
        indices = torch.randint(
            0, len(self.slots), (self.num_samples,), generator=g
        ).numpy()

        # Yield batches of indices
        for start in range(0, self.num_samples, self.batch_size):
            end = min(start + self.batch_size, self.num_samples)
            batch_indices = indices[start:end]
            if self.targets is None:
                # No cached targets were provided; yield slots only
                yield torch.from_numpy(self.slots[batch_indices])
            else:
                yield (
                    torch.from_numpy(self.slots[batch_indices]),
                    torch.from_numpy(self.targets[batch_indices]),
                )
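

# Usage sketch (illustrative only, never called by this module): wiring the cached-data
# loader together with the optimizer and scheduler helpers. All sizes and hyperparameters
# below, the toy linear head, and the scheduler arguments are hypothetical placeholder
# values, not settings from this repository; the exact meaning of the scheduler arguments
# depends on build_scheduler.
def _example_cached_training_setup():
    import torch.nn as nn

    # Fake cached slots/targets, e.g. precomputed latents and class labels.
    slots = np.random.randn(256, 16).astype(np.float32)
    targets = np.random.randint(0, 10, size=(256,)).astype(np.int64)

    loader = CacheDataLoader(slots, targets, batch_size=32, num_augs=1, seed=0)
    loader.set_epoch(0)

    head = nn.Linear(16, 10)  # stand-in model
    optimizer = create_optimizer(head, weight_decay=0.05, learning_rate=1e-4)
    scheduler = create_scheduler(
        optimizer,
        num_epoch=10,
        steps_per_epoch=len(loader),
        lr_min=1e-6,
        warmup_steps=len(loader),
        warmup_lr_init=1e-7,
        decay_steps=len(loader),
        cosine_lr=True,
    )

    for slot_batch, target_batch in loader:
        loss = nn.functional.cross_entropy(head(slot_batch), target_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return scheduler  # how/when to step it depends on build_scheduler's API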