import os
import importlib
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import torch
import torch_fidelity
from torch.optim import AdamW

from semanticist.utils.lr_scheduler import build_scheduler

def get_obj_from_str(string, reload=False):
    """Get an object from a dotted string path."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

def instantiate_from_config(config):
    """Instantiate an object from a config dictionary."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
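
# Example (sketch): instantiating a class from a YAML-style config dict.
# The target below is a real torch class, used purely for illustration:
#
#   config = {"target": "torch.nn.Linear",
#             "params": {"in_features": 16, "out_features": 4}}
#   layer = instantiate_from_config(config)  # equivalent to nn.Linear(16, 4)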

def is_dist_avail_and_initialized():
    """Check if distributed training is available and initialized."""
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return True

def is_main_process():
    """Check if the current process is the main process."""
    return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0

def concat_all_gather(tensor):
    """
    Perform an all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)
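
# Example (sketch): gathering per-rank features once a process group is
# initialized; `feats` is hypothetical and must have the same shape on every rank.
#
#   feats = model(images)                  # [local_batch, dim] on each rank
#   all_feats = concat_all_gather(feats)   # [local_batch * world_size, dim]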

def requires_grad(model, flag=True):
    """Set requires_grad flag for all model parameters."""
    for p in model.parameters():
        p.requires_grad = flag

def save_img(img, save_path):
    """Save a single image tensor (C, H, W, values in [0, 1]) to disk."""
    img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]  # RGB to BGR for cv2
    cv2.imwrite(save_path, img)

def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once
    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [pool.submit(cv2.imwrite, path, img)
                   for path, img in zip(save_paths, imgs)]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred
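
# Example (sketch): dumping a reconstructed batch to disk; assumes `recs` is a
# CPU float tensor of RGB images in [0, 1] with shape [B, 3, H, W].
#
#   paths = [os.path.join(image_saved_dir, f"{i:05d}.png") for i in range(recs.shape[0])]
#   save_img_batch(recs, paths)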

def get_fid_stats(real_dir, rec_dir, fid_stats):
    """Calculate FID and Inception Score for reconstructed images against real ones."""
    stats = torch_fidelity.calculate_metrics(
        input1=rec_dir,
        input2=real_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats
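
# Example (sketch): computing metrics after both image folders are written;
# the directory and stats-file paths are hypothetical.
#
#   stats = get_fid_stats("results/images/real", "results/images/rec",
#                         "fid_stats/imagenet256_stats.npz")
#   print(stats["frechet_inception_distance"], stats["inception_score_mean"])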

def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps,
                     warmup_lr_init, decay_steps, cosine_lr):
    """Create a learning rate scheduler."""
    scheduler = build_scheduler(
        optimizer,
        num_epoch,
        steps_per_epoch,
        lr_min,
        warmup_steps,
        warmup_lr_init,
        decay_steps,
        cosine_lr,
    )
    return scheduler

def load_state_dict(state_dict, model):
    """Helper to load a state dict with proper prefix handling."""
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Remove the '_orig_mod.' prefix that torch.compile adds, if present
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    if is_main_process():
        print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

def load_safetensors(path, model):
    """Helper to load a safetensors checkpoint."""
    from safetensors.torch import safe_open
    with safe_open(path, framework="pt", device="cpu") as f:
        state_dict = {k: f.get_tensor(k) for k in f.keys()}
    load_state_dict(state_dict, model)
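
# Example (sketch): resuming weights from a .safetensors checkpoint;
# the path is hypothetical.
#
#   load_safetensors("results/models/model.safetensors", model)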

def setup_result_folders(result_folder):
    """Set up result folders for saving models and images."""
    model_saved_dir = os.path.join(result_folder, "models")
    os.makedirs(model_saved_dir, exist_ok=True)
    image_saved_dir = os.path.join(result_folder, "images")
    os.makedirs(image_saved_dir, exist_ok=True)
    return model_saved_dir, image_saved_dir

def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
    """Create an AdamW optimizer that applies weight decay to 2D+ parameters only."""
    # Start with all of the candidate parameters,
    param_dict = {pn: p for pn, p in model.named_parameters()}
    # then filter out those that do not require grad.
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # Create optim groups. Any parameter that is at least 2D will be weight decayed, otherwise not:
    # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    if is_main_process():
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
    return optimizer
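
# Example (sketch): wiring the optimizer and scheduler together; the
# hyperparameter values are illustrative, not this repo's defaults.
#
#   optimizer = create_optimizer(model, weight_decay=0.05, learning_rate=1e-4)
#   scheduler = create_scheduler(optimizer, num_epoch=100, steps_per_epoch=1000,
#                                lr_min=1e-6, warmup_steps=5000, warmup_lr_init=1e-7,
#                                decay_steps=None, cosine_lr=True)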

class EMAModel:
    """Model Exponential Moving Average."""
    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    # ema = decay * ema + (1 - decay) * param
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )
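
# Example (sketch): keeping an EMA shadow of the weights during training;
# `compute_loss` is a hypothetical helper, not part of this module.
#
#   ema = EMAModel(model, device, decay=0.999)
#   for batch in loader:
#       loss = compute_loss(model, batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema.update(model)   # fold the new weights into the EMA copy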

class PaddedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that pads a dataset to ensure even distribution across processes."""
    def __init__(self, dataset, padding_size):
        self.dataset = dataset
        self.padding_size = padding_size

    def __len__(self):
        return len(self.dataset) + self.padding_size

    def __getitem__(self, idx):
        if idx < len(self.dataset):
            return self.dataset[idx]
        # Padding indices fall back to the first item
        return self.dataset[0]
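
# Example (sketch): choosing the padding so every rank gets the same number of
# samples when the padded dataset is sharded across the world size.
#
#   world_size = torch.distributed.get_world_size()
#   pad = (world_size - len(dataset) % world_size) % world_size
#   padded = PaddedDataset(dataset, pad)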

class CacheDataLoader:
    """DataLoader-like interface for cached data with epoch-based shuffling."""
    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # Original dataset size (before augmentations)
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set the epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of batches, based on the original dataset size."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield randomly sampled batches for the current epoch."""
        g = torch.Generator()
        g.manual_seed((self.seed + self.epoch) if self.seed is not None else self.epoch)
        # Randomly sample indices (with replacement) from the entire augmented dataset
        indices = torch.randint(
            0, len(self.slots),
            (self.num_samples,),
            generator=g
        ).numpy()
        # Yield batches of indices
        for start in range(0, self.num_samples, self.batch_size):
            end = min(start + self.batch_size, self.num_samples)
            batch_indices = indices[start:end]
            if self.targets is None:
                # No cached targets: yield the slots alone
                yield torch.from_numpy(self.slots[batch_indices])
            else:
                yield (
                    torch.from_numpy(self.slots[batch_indices]),
                    torch.from_numpy(self.targets[batch_indices])
                )
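
# Example (sketch): iterating over pre-extracted slots/targets held in numpy
# arrays; shapes and hyperparameters are illustrative.
#
#   loader = CacheDataLoader(slots, targets, batch_size=256, num_augs=4, seed=0)
#   for epoch in range(num_epochs):
#       loader.set_epoch(epoch)   # reshuffle deterministically per epoch
#       for slot_batch, target_batch in loader:
#           ...                   # training step on the cached batch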