import logging

import numpy as np
import torch
from PIL import Image


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and to stderr.

    Configures the root logger via ``logging.basicConfig`` with a
    ``StreamHandler`` (which writes to ``sys.stderr`` by default) and a
    ``FileHandler`` at ``{logging_dir}/log.txt``, then returns this module's
    logger.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first call in a process actually installs them.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    )
    logger = logging.getLogger(__name__)
    return logger


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    For every named parameter of ``model``, the same-named parameter of
    ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``.
    Raises KeyError if ``ema_model`` lacks a parameter name present in ``model``.
    """
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image (BOX filter) while its short side is at least
    ``2 * image_size``, resizes it (BICUBIC) so the short side is
    ``image_size``, and returns the central ``image_size x image_size`` crop
    as a new PIL image.
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


def crop_arr(pil_image, max_image_size):
    """
    Resize an image to fit ``max_image_size`` and trim both sides to multiples of 16.

    Steps:
      1. Halve both sides (BOX filter) while the short side is at least
         ``2 * max_image_size``.
      2. If the long side still exceeds ``max_image_size``, rescale (BICUBIC)
         so the long side becomes ``max_image_size`` (after rounding).
      3. If the short side is below 16, upscale (BICUBIC) so it becomes 16.
         For extreme aspect ratios this can push the long side back above
         ``max_image_size``.
      4. Crop ``size % 16`` pixels from each axis, split between both ends
         (the extra odd pixel comes off the bottom/right), so height and width
         are multiples of 16.

    Returns a new PIL image.
    """
    while min(*pil_image.size) >= 2 * max_image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    if max(*pil_image.size) > max_image_size:
        scale = max_image_size / max(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

    if min(*pil_image.size) < 16:
        scale = 16 / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

    arr = np.array(pil_image)
    crop_y1 = (arr.shape[0] % 16) // 2
    crop_y2 = arr.shape[0] % 16 - crop_y1
    crop_x1 = (arr.shape[1] % 16) // 2
    crop_x2 = arr.shape[1] % 16 - crop_x1
    arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2]
    return Image.fromarray(arr)


def vae_encode(vae, x, weight_dtype):
    """
    Encode an input batch into scaled VAE latents.

    Draws a sample from ``vae.encode(x).latent_dist``. If the VAE config has a
    non-None ``shift_factor``, the sample ``z`` is mapped to
    ``(z - shift_factor) * scaling_factor``. Otherwise it is only multiplied
    by ``scaling_factor``. The result is cast to ``weight_dtype``.

    Args:
        vae: object exposing ``encode(x).latent_dist.sample()`` and a
            ``config`` with ``scaling_factor`` (and optionally
            ``shift_factor``). Presumably a diffusers ``AutoencoderKL``;
            confirm against callers.
        x: input batch, or ``None``.
        weight_dtype: dtype the latents are cast to.

    Returns:
        The latents, or ``None`` when ``x`` is ``None``.
    """
    if x is None:
        return None

    # getattr: VAE configs without a ``shift_factor`` entry (e.g. older
    # library versions) are treated as "no shift" instead of raising
    # AttributeError.
    shift_factor = getattr(vae.config, "shift_factor", None)
    latents = vae.encode(x).latent_dist.sample()
    if shift_factor is not None:
        latents = (latents - shift_factor) * vae.config.scaling_factor
    else:
        latents = latents.mul_(vae.config.scaling_factor)
    return latents.to(weight_dtype)


def vae_encode_list(vae, x, weight_dtype):
    """
    Encode each element of ``x`` separately with ``vae_encode``.

    Elements are encoded one at a time rather than stacked, so they need not
    share a shape. ``None`` entries map to ``None``. Returns a list in input
    order.
    """
    return [vae_encode(vae, img, weight_dtype) for img in x]