diff --git a/Control-Color/CtrlColor_environ.yaml b/Control-Color/CtrlColor_environ.yaml deleted file mode 100644 index e6bc63e985c7c9634efbc17519f489e52436c86a..0000000000000000000000000000000000000000 --- a/Control-Color/CtrlColor_environ.yaml +++ /dev/null @@ -1,40 +0,0 @@ -name: CtrlColor -channels: - - pytorch - - defaults -dependencies: - - python=3.8.5 - - pip=20.3 - - cudatoolkit=11.3 - - pytorch=1.12.1 - - torchvision=0.13.1 - - numpy=1.23.1 - - pip: - - gradio==3.31.0 - - gradio-client==0.2.5 - - albumentations==1.3.0 - - opencv-python==4.9.0.80 - - opencv-python-headless==4.5.5.64 - - imageio==2.9.0 - - imageio-ffmpeg==0.4.2 - - pytorch-lightning==1.5.0 - - omegaconf==2.1.1 - - test-tube>=0.7.5 - - streamlit==1.12.1 - - webdataset==0.2.5 - - kornia==0.6 - - open_clip_torch==2.0.2 - - invisible-watermark>=0.1.5 - - streamlit-drawable-canvas==0.8.0 - - torchmetrics==0.6.0 - - addict==2.4.0 - - yapf==0.32.0 - - prettytable==3.6.0 - - basicsr==1.4.2 - - salesforce-lavis==1.0.2 - - grpcio==1.60 - - pydantic==1.10.5 - - spacy==3.5.1 - - typer==0.7.0 - - typing-extensions==4.4.0 - - fastapi==0.92.0 \ No newline at end of file diff --git a/Control-Color/annotator/__pycache__/util.cpython-38.pyc b/Control-Color/annotator/__pycache__/util.cpython-38.pyc deleted file mode 100644 index c4c4b1cdfc704d74b0b1c70d721b3b9878d84abb..0000000000000000000000000000000000000000 Binary files a/Control-Color/annotator/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/annotator/util.py b/Control-Color/annotator/util.py deleted file mode 100644 index 10e8f0cb6bd9deeff8995b2f72be2e4ea6df343e..0000000000000000000000000000000000000000 --- a/Control-Color/annotator/util.py +++ /dev/null @@ -1,40 +0,0 @@ -import numpy as np -import cv2 -import os - - -annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') - - -def HWC3(x): - assert x.dtype == np.uint8 - if x.ndim == 2: - x = x[:, :, None] - assert x.ndim == 3 - H, W, C = x.shape - assert C == 1 or C == 3 or C == 4 - if C == 3: - return x - if C == 1: - return np.concatenate([x, x, x], axis=2) - if C == 4: - color = x[:, :, 0:3].astype(np.float32) - alpha = x[:, :, 3:4].astype(np.float32) / 255.0 - y = color * alpha + 255.0 * (1.0 - alpha) - y = y.clip(0, 255).astype(np.uint8) - return y - - -def resize_image(input_image, resolution): - H, W, C = input_image.shape - H = float(H) - W = float(W) - k = float(resolution) / min(H, W)#min(H,W) - H *= k - W *= k - H_new = int(np.round(H / 64.0)) * 64 - W_new = int(np.round(W / 64.0)) * 64 - H = H_new if H_new<800 else int(np.round(800 / 64.0)) * 64#1024->896 - W=W_new if W_new<800 else int(np.round(800 / 64.0)) * 64 - img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) - return img diff --git a/Control-Color/app.py b/Control-Color/app.py deleted file mode 100644 index af33adb5f1a6fa0cef952b1ccb32dafdb06f0230..0000000000000000000000000000000000000000 --- a/Control-Color/app.py +++ /dev/null @@ -1,524 +0,0 @@ -import os -from share import * -import config - -import cv2 -import einops -import gradio as gr -import numpy as np -import torch -import random - -from pytorch_lightning import seed_everything -from annotator.util import resize_image -from cldm.model import create_model, load_state_dict -from cldm.ddim_haced_sag_step import DDIMSampler -from lavis.models import load_model_and_preprocess -from PIL import Image -import tqdm - -from ldm.models.autoencoder_train import AutoencoderKL - 
-ckpt_path="./pretrained_models/main_model.ckpt" - -model = create_model('./models/cldm_v15_inpainting_infer1.yaml').cpu() -model.load_state_dict(load_state_dict(ckpt_path, location='cuda'),strict=False) -model = model.cuda() - -ddim_sampler = DDIMSampler(model) - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -BLIP_model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device) - -vae_model_ckpt_path="./pretrained_models/content-guided_deformable_vae.ckpt" - -def load_vae(): - init_config = { - "embed_dim": 4, - "monitor": "val/rec_loss", - "ddconfig":{ - "double_z": True, - "z_channels": 4, - "resolution": 256, - "in_channels": 3, - "out_ch": 3, - "ch": 128, - "ch_mult":[1,2,4,4], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0.0, - }, - "lossconfig":{ - "target": "ldm.modules.losses.LPIPSWithDiscriminator", - "params":{ - "disc_start": 501, - "kl_weight": 0, - "disc_weight": 0.025, - "disc_factor": 1.0 - } - } - } - vae = AutoencoderKL(**init_config) - vae.load_state_dict(load_state_dict(vae_model_ckpt_path, location='cuda')) - vae = vae.cuda() - return vae - -vae_model=load_vae() - -def encode_mask(mask,masked_image): - mask = torch.nn.functional.interpolate(mask, size=(mask.shape[2] // 8, mask.shape[3] // 8)) - # mask=torch.cat([mask] * 2) #if do_classifier_free_guidance else mask - mask = mask.to(device="cuda") - # do_classifier_free_guidance=False - masked_image_latents = model.get_first_stage_encoding(model.encode_first_stage(masked_image.cuda())).detach() - return mask,masked_image_latents - -def get_mask(input_image,hint_image): - mask=input_image.copy() - H,W,C=input_image.shape - for i in range(H): - for j in range(W): - if input_image[i,j,0]==hint_image[i,j,0]: - # print(input_image[i,j,0]) - mask[i,j,:]=255. - else: - mask[i,j,:]=0. #input_image[i,j,:] - kernel=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3)) - mask=cv2.morphologyEx(np.array(mask),cv2.MORPH_OPEN,kernel,iterations=1) - return mask - -def prepare_mask_and_masked_image(image, mask): - """ - Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (_type_): The mask to apply to the image, i.e. regions to inpaint. - It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. 
- """ - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") - - # Batch single image - if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Single batched mask, no channel dim or single mask not batched but channel dim - if mask.shape[0] == 1: - mask = mask.unsqueeze(0) - - # Batched masks no channel dim - else: - mask = mask.unsqueeze(1) - - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") - else: - # preprocess image - if isinstance(image, (Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], Image.Image): - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - return mask, masked_image - -# generate image -generator = torch.manual_seed(859311133)#0 -def path2L(img_path): - raw_image = cv2.imread(img_path) - raw_image = cv2.cvtColor(raw_image,cv2.COLOR_BGR2LAB) - raw_image_input = cv2.merge([raw_image[:,:,0],raw_image[:,:,0],raw_image[:,:,0]]) - return raw_image_input - -def is_gray_scale(img, threshold=10): - img = Image.fromarray(img) - if len(img.getbands()) == 1: - return True - img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) - img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) - img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) - diff1 = (img1 - img2).var() - diff2 = (img2 - img3).var() - diff3 = (img3 - img1).var() - diff_sum = (diff1 + diff2 + diff3) / 3.0 - if diff_sum <= threshold: - return True - else: - return False - -def randn_tensor( - shape, - generator= None, - device= None, - dtype=None, - layout= None, -): - """A helper function to create random tensors on the desired `device` with the desired 
`dtype`. When - passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor - is always created on the CPU. - """ - # device on which tensor is created defaults to device - rand_device = device - batch_size = shape[0] - - layout = layout or torch.strided - device = device or torch.device("cpu") - - if generator is not None: - gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type - if gen_device_type != device.type and gen_device_type == "cpu": - rand_device = "cpu" - if device != "mps": - print("The passed generator was created on 'cpu' even though a tensor on {device} was expected.") - # logger.info( - # f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." - # f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" - # f" slighly speed up this function by passing a generator that was created on the {device} device." - # ) - elif gen_device_type != device.type and gen_device_type == "cuda": - raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") - - # make sure generator list of length 1 is treated like a non-list - if isinstance(generator, list) and len(generator) == 1: - generator = generator[0] - - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - - return latents - -def add_noise( - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32) - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - - return noisy_samples - -def set_timesteps(num_inference_steps: int, timestep_spacing="leading",device=None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - num_train_timesteps=1000 - if num_inference_steps > num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {num_train_timesteps} timesteps." 
- ) - - num_inference_steps = num_inference_steps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, num_train_timesteps - 1, num_inference_steps) - .round()[::-1] - .copy() - .astype(np.int64) - ) - elif timestep_spacing == "leading": - step_ratio = num_train_timesteps // num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - # timesteps += steps_offset - elif timestep_spacing == "trailing": - step_ratio = num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - timesteps -= 1 - else: - raise ValueError( - f"{timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." - ) - - timesteps = torch.from_numpy(timesteps).to(device) - return timesteps - -def get_timesteps(num_inference_steps, timesteps_set, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = timesteps_set[t_start * 1 :] - - return timesteps, num_inference_steps - t_start - - -def get_noised_image_latents(img,W,H,ddim_steps,strength,seed,device): - img1 = [cv2.resize(img,(W,H))] - img1 = np.concatenate([i[None, :] for i in img1], axis=0) - img1 = img1.transpose(0, 3, 1, 2) - img1 = torch.from_numpy(img1).to(dtype=torch.float32) /127.5 - 1.0 - - image_latents=model.get_first_stage_encoding(model.encode_first_stage(img1.cuda())).detach() - shape=image_latents.shape - generator = torch.manual_seed(seed) - - noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) - - timesteps_set=set_timesteps(ddim_steps,timestep_spacing="linspace", device=device) - timesteps, num_inference_steps = get_timesteps(ddim_steps, timesteps_set, strength, device) - latent_timestep = timesteps[1].repeat(1 * 1) - - init_latents = add_noise(image_latents, noise, torch.tensor(latent_timestep)) - for j in range(0, 1000, 100): - - x_samples=model.decode_first_stage(add_noise(image_latents, noise, torch.tensor(j))) - init_image=(einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) - - cv2.imwrite("./initlatents1/"+str(j)+"init_image.png",cv2.cvtColor(init_image[0],cv2.COLOR_RGB2BGR)) - return init_latents - -def process(using_deformable_vae,change_according_to_strokes,iterative_editing,input_image,hint_image,prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, sag_scale,SAG_influence_step, seed, eta): - torch.cuda.empty_cache() - with torch.no_grad(): - ref_flag=True - input_image_ori=input_image - if is_gray_scale(input_image): - print("It is a greyscale image.") - # mask=get_mask(input_image,hint_image) - else: - print("It is a color image.") - input_image_ori=input_image - input_image=cv2.cvtColor(input_image,cv2.COLOR_RGB2LAB)[:,:,0] - input_image=cv2.merge([input_image,input_image,input_image]) - mask=get_mask(input_image_ori,hint_image) - cv2.imwrite("gradio_mask1.png",mask) - - if iterative_editing: - 
mask=255-mask - if change_according_to_strokes: - hint_image=mask/255.*hint_image+(1-mask/255.)*input_image_ori - else: - hint_image=mask/255.*input_image+(1-mask/255.)*input_image_ori - else: - hint_image=mask/255.*input_image+(1-mask/255.)*hint_image - hint_image=hint_image.astype(np.uint8) - if len(prompt)==0: - image = Image.fromarray(input_image) - image = vis_processors["eval"](image).unsqueeze(0).to(device) - prompt = BLIP_model.generate({"image": image})[0] - if "a black and white photo of" in prompt or "black and white photograph of" in prompt: - prompt=prompt.replace(prompt[:prompt.find("of")+3],"") - print(prompt) - H_ori,W_ori,C_ori=input_image.shape - img = resize_image(input_image, image_resolution) - mask = resize_image(mask, image_resolution) - hint_image =resize_image(hint_image,image_resolution) - mask,masked_image=prepare_mask_and_masked_image(Image.fromarray(hint_image),Image.fromarray(mask)) - mask,masked_image_latents=encode_mask(mask,masked_image) - H, W, C = img.shape - - # if ref_image is None: - ref_image=np.array([[[0]*C]*W]*H).astype(np.float32) - # print(ref_image.shape) - # ref_flag=False - ref_image=resize_image(ref_image,image_resolution) - - # cv2.imwrite("exemplar_image.png",cv2.cvtColor(ref_image,cv2.COLOR_RGB2BGR)) - - # ddim_steps=1 - control = torch.from_numpy(img.copy()).float().cuda() / 255.0 - control = torch.stack([control for _ in range(num_samples)], dim=0) - control = einops.rearrange(control, 'b h w c -> b c h w').clone() - - if seed == -1: - seed = random.randint(0, 65535) - seed_everything(seed) - - ref_image=cv2.resize(ref_image,(W,H)) - - ref_image=torch.from_numpy(ref_image).cuda().unsqueeze(0) - - init_latents=None - - if config.save_memory: - model.low_vram_shift(is_diffusing=False) - - print("no reference images, using Frozen encoder") - cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} - un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} - shape = (4, H // 8, W // 8) - - if config.save_memory: - model.low_vram_shift(is_diffusing=True) - noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) - model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. 
Perhaps because 0.825**12<0.01 but 0.826**12>0.01 - samples, intermediates = ddim_sampler.sample(model,ddim_steps, num_samples, - shape, cond, mask=mask, masked_image_latents=masked_image_latents,verbose=False, eta=eta, - # x_T=image_latents, - x_T=init_latents, - unconditional_guidance_scale=scale, - sag_scale = sag_scale, - SAG_influence_step=SAG_influence_step, - noise = noise, - unconditional_conditioning=un_cond) - - - if config.save_memory: - model.low_vram_shift(is_diffusing=False) - - if not using_deformable_vae: - x_samples = model.decode_first_stage(samples) - else: - samples = model.decode_first_stage_before_vae(samples) - gray_content_z=vae_model.get_gray_content_z(torch.from_numpy(img.copy()).float().cuda() / 255.0) - # print(gray_content_z.shape) - x_samples = vae_model.decode(samples,gray_content_z) - - x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) - - #single image replace L channel - results_ori = [x_samples[i] for i in range(num_samples)] - results_ori=[cv2.resize(i,(W_ori,H_ori),interpolation=cv2.INTER_LANCZOS4) for i in results_ori] - - cv2.imwrite("result_ori.png",cv2.cvtColor(results_ori[0],cv2.COLOR_RGB2BGR)) - - results_tmp=[cv2.cvtColor(np.array(i),cv2.COLOR_RGB2LAB) for i in results_ori] - results=[cv2.merge([input_image[:,:,0],tmp[:,:,1],tmp[:,:,2]]) for tmp in results_tmp] - results_mergeL=[cv2.cvtColor(np.asarray(i),cv2.COLOR_LAB2RGB) for i in results]#cv2.COLOR_LAB2BGR) - cv2.imwrite("output.png",cv2.cvtColor(results_mergeL[0],cv2.COLOR_RGB2BGR)) - return results_mergeL - -def get_grayscale_img(img, progress=gr.Progress(track_tqdm=True)): - torch.cuda.empty_cache() - for j in tqdm.tqdm(range(1),desc="Uploading input..."): - return img,"Uploading input image done." - -block = gr.Blocks().queue() -with block: - with gr.Row(): - gr.Markdown("## Control-Color")#("## Color-Anything")#Control Stable Diffusion with L channel - with gr.Row(): - with gr.Column(): - # input_image = gr.Image(source='upload', type="numpy") - grayscale_img = gr.Image(visible=False, type="numpy") - input_image = gr.Image(source='upload',tool='color-sketch',interactive=True) - Grayscale_button = gr.Button(value="Upload input image") - text_out = gr.Textbox(value="Please upload input image first, then draw the strokes or input text prompts or give reference images as you wish.") - prompt = gr.Textbox(label="Prompt") - change_according_to_strokes = gr.Checkbox(label='Change according to strokes\' color', value=True) - iterative_editing = gr.Checkbox(label='Only change the strokes\' area', value=False) - using_deformable_vae = gr.Checkbox(label='Using deformable vae. 
(Less color overflow)', value=False) - # with gr.Accordion("Input Reference", open=False): - # ref_image = gr.Image(source='upload', type="numpy") - run_button = gr.Button(label="Upload prompts/strokes (optional) and Run",value="Upload prompts/strokes (optional) and Run") - with gr.Accordion("Advanced options", open=False): - num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) - image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64) - strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) - guess_mode = gr.Checkbox(label='Guess Mode', value=False) - #detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1) - ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) - scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.0, step=0.1)#value=9.0 - sag_scale = gr.Slider(label="SAG Scale", minimum=0.0, maximum=1.0, value=0.05, step=0.01)#0.08 - SAG_influence_step = gr.Slider(label="1000-SAG influence step", minimum=0, maximum=900, value=600, step=50) - seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)#94433242802 - eta = gr.Number(label="eta (DDIM)", value=0.0) - a_prompt = gr.Textbox(label="Added Prompt", value='best quality, detailed, real')#extremely detailed - n_prompt = gr.Textbox(label="Negative Prompt", - value='a black and white photo, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') - with gr.Column(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') - # grayscale_img = gr.Image(interactive=False,visible=False) - - Grayscale_button.click(fn=get_grayscale_img,inputs=input_image,outputs=[grayscale_img,text_out]) - ips = [using_deformable_vae,change_according_to_strokes,iterative_editing,grayscale_img,input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale,sag_scale,SAG_influence_step, seed, eta] - run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) - - -block.launch(server_name='0.0.0.0',share=True) diff --git a/Control-Color/cldm/__pycache__/cldm.cpython-38.pyc b/Control-Color/cldm/__pycache__/cldm.cpython-38.pyc deleted file mode 100644 index 8b474b1403294133f55f1a1e4aa6bff8dbd25f17..0000000000000000000000000000000000000000 Binary files a/Control-Color/cldm/__pycache__/cldm.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/cldm/__pycache__/ddim_haced_sag_step.cpython-38.pyc b/Control-Color/cldm/__pycache__/ddim_haced_sag_step.cpython-38.pyc deleted file mode 100644 index 61d721e782ae2b56589687980d0afc916a8dd6cc..0000000000000000000000000000000000000000 Binary files a/Control-Color/cldm/__pycache__/ddim_haced_sag_step.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/cldm/__pycache__/hack.cpython-310.pyc b/Control-Color/cldm/__pycache__/hack.cpython-310.pyc deleted file mode 100644 index 85434e64807fcea5eb7626407285128b18062603..0000000000000000000000000000000000000000 Binary files a/Control-Color/cldm/__pycache__/hack.cpython-310.pyc and /dev/null differ diff --git a/Control-Color/cldm/__pycache__/hack.cpython-38.pyc b/Control-Color/cldm/__pycache__/hack.cpython-38.pyc deleted file mode 100644 index f01b79c8798e468df4b471dd5ce7781b1fc18532..0000000000000000000000000000000000000000 Binary files 
a/Control-Color/cldm/__pycache__/hack.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/cldm/__pycache__/model.cpython-38.pyc b/Control-Color/cldm/__pycache__/model.cpython-38.pyc deleted file mode 100644 index 65e96c4d542cdaeb5e4f865eb2ed45432c7161b8..0000000000000000000000000000000000000000 Binary files a/Control-Color/cldm/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/cldm/cldm.py b/Control-Color/cldm/cldm.py deleted file mode 100644 index b1492cb52c3aaf924dfa6d213b330f881fd98fb0..0000000000000000000000000000000000000000 --- a/Control-Color/cldm/cldm.py +++ /dev/null @@ -1,547 +0,0 @@ -import einops -import torch -import torch as th -import torch.nn as nn - -from ldm.modules.diffusionmodules.util import ( - conv_nd, - linear, - zero_module, - timestep_embedding, -) - -from einops import rearrange, repeat -from torchvision.utils import make_grid -from ldm.modules.attention import SpatialTransformer -from ldm.modules.attention_dcn_control import SpatialTransformer_dcn -from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.util import log_txt_as_img, exists, instantiate_from_config -from ldm.models.diffusion.ddim import DDIMSampler - - -class ControlledUnetModel(UNetModel): - def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): - hs = [] - # print("timestep",timesteps) - with torch.no_grad(): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - # print("t_emb",t_emb) - emb = self.time_embed(t_emb) - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context)#,timestep=timesteps) - hs.append(h) - h = self.middle_block(h, emb, context)#,timestep=timesteps) - - if control is not None: - h += control.pop() - - for i, module in enumerate(self.output_blocks): - # print("output_blocks0",h.shape) - if only_mid_control or control is None: - h = torch.cat([h, hs.pop()], dim=1) - else: - h = torch.cat([h, hs.pop() + control.pop()], dim=1) - h = module(h, emb, context)#,timestep=timesteps) - - # print("output_blocks",h.shape) - - h = h.type(x.dtype) - h=self.out(h) - # print("self.ot",h.shape) - return h - - -class ControlNet(nn.Module): - def __init__( - self, - image_size, - in_channels, - model_channels, - hint_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None, - disable_middle_self_attn=False, - use_linear_in_transformer=False, - ): - super().__init__() - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' - - if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
- from omegaconf.listconfig import ListConfig - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' - - if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - - self.dims = dims - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") - self.num_res_blocks = num_res_blocks - if disable_self_attentions is not None: - # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not - assert len(disable_self_attentions) == len(channel_mult) - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) - - self.input_hint_block = TimestepEmbedSequential( - conv_nd(dims, hint_channels, 16, 3, padding=1), - nn.SiLU(), - conv_nd(dims, 16, 16, 3, padding=1), - nn.SiLU(), - conv_nd(dims, 16, 32, 3, padding=1, stride=2), - nn.SiLU(), - conv_nd(dims, 32, 32, 3, padding=1), - nn.SiLU(), - conv_nd(dims, 32, 96, 3, padding=1, stride=2), - nn.SiLU(), - conv_nd(dims, 96, 96, 3, padding=1), - nn.SiLU(), - conv_nd(dims, 96, 256, 3, padding=1, stride=2), - nn.SiLU(), - zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) - ) - - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if 
legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch)) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self.middle_block_out = self.make_zero_conv(ch) - self._feature_size += ch - - def make_zero_conv(self, channels): - return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) - - def forward(self, x, hint, timesteps, context, **kwargs): - # print("cldm",hint.shape,x.shape) - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - guided_hint = self.input_hint_block(hint, emb, context) - - outs = [] - - h = x.type(self.dtype) - # h_in=h - - for module, zero_conv in zip(self.input_blocks, self.zero_convs): - if guided_hint is not None: - h = module(h, emb, context)#,dcn_guide=h_in) - h += guided_hint - guided_hint = None - else: - # print("dcn_guide") - h = module(h, emb, context)#,dcn_guide=h_in) - outs.append(zero_conv(h, emb, context)) - - h = self.middle_block(h, emb, context)#,dcn_guide=h_in) - outs.append(self.middle_block_out(h, emb, context)) - - return outs - - -class ControlLDM(LatentDiffusion): - - def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs): 
#freeze - # print(control_stage_config) - super().__init__(*args, **kwargs) - self.control_model = instantiate_from_config(control_stage_config) - self.control_key = control_key - self.only_mid_control = only_mid_control - self.control_scales = [1.0] * 13 - # if freeze==True: - # self.freeze() - - # def freeze(self): - # #self.train = disabled_train - # for param in self.parameters(): - # param.requires_grad = False - - - - @torch.no_grad() - def get_input(self, batch, k, bs=None, *args, **kwargs): - x,mask,masked_image_latents, c = super().get_input(batch, self.first_stage_key, *args, **kwargs) - control = batch[self.control_key] - if bs is not None: - control = control[:bs] - control = control.to(self.device) - control = einops.rearrange(control, 'b h w c -> b c h w') - control = control.to(memory_format=torch.contiguous_format).float() - return x,mask,masked_image_latents, dict(c_crossattn=[c], c_concat=[control]) - - def apply_model(self, x_noisy,mask,masked_image_latents, t, cond, *args, **kwargs): - assert isinstance(cond, dict) - diffusion_model = self.model.diffusion_model - - cond_txt = torch.cat(cond['c_crossattn'], 1) - # print(cond_txt.shape,cond['c_crossattn'].shape) - if cond['c_concat'] is None: - eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control) - else: - control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt) - control = [c * scale for c, scale in zip(control, self.control_scales)] - mask=torch.cat([mask] * x_noisy.shape[0]) - masked_image_latents=torch.cat([masked_image_latents] * x_noisy.shape[0]) - x_noisy = torch.cat([x_noisy,mask,masked_image_latents], dim=1) - eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control) - - return eps - - def apply_model_addhint(self, x_noisy,mask,masked_image_latents, t, cond, *args, **kwargs): - assert isinstance(cond, dict) - diffusion_model = self.model.diffusion_model - - cond_txt = torch.cat(cond['c_crossattn'], 1) - # print(cond_txt.shape,cond['c_crossattn'].shape) - if cond['c_concat'] is None: - eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control) - else: - control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt) - control = [c * scale for c, scale in zip(control, self.control_scales)] - # print(x_noisy.shape,mask.shape,masked_image_latents.shape) - x_noisy = torch.cat([x_noisy,mask,masked_image_latents], dim=1) - eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control) - - return eps - - @torch.no_grad() - def get_unconditional_conditioning(self, N): - return self.get_learned_conditioning([""] * N) - # def get_unconditional_conditioning(self, N,hint_image): - # hint_image[:,:,:,:]=0 - # return self.get_learned_conditioning(([""] * N,hint_image)) - - # @torch.no_grad() - # def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None, - # quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - # plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None, - # use_ema_scope=True, - # **kwargs): - # use_ddim = ddim_steps is not None - - # log = dict() - # z,mask,masked_image_latents, c = self.get_input(batch, self.first_stage_key, bs=N) - # c_cat, c = 
c["c_concat"][0][:N], c["c_crossattn"][0][:N] - # N = min(z.shape[0], N) - # n_row = min(z.shape[0], n_row) - # log["reconstruction"] = self.decode_first_stage(z) - # log["control"] = c_cat * 2.0 - 1.0 - # log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) - # txt,hint_image=batch[self.cond_stage_key] - # if plot_diffusion_rows: - # # get diffusion row - # diffusion_row = list() - # z_start = z[:n_row] - # for t in range(self.num_timesteps): - # if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - # t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - # t = t.to(self.device).long() - # noise = torch.randn_like(z_start) - # z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - # diffusion_row.append(self.decode_first_stage(z_noisy)) - - # diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - # diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - # diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - # diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - # log["diffusion_row"] = diffusion_grid - - # if sample: - # # get denoise row - # samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - # batch_size=N, ddim=use_ddim, - # ddim_steps=ddim_steps, eta=ddim_eta) - # x_samples = self.decode_first_stage(samples) - # log["samples"] = x_samples - # if plot_denoise_rows: - # denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - # log["denoise_row"] = denoise_grid - - # if unconditional_guidance_scale > 1.0: - # uc_cross = self.get_unconditional_conditioning(N,hint_image) - # uc_cat = c_cat # torch.zeros_like(c_cat) - # uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} - # samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - # batch_size=N, ddim=use_ddim, - # ddim_steps=ddim_steps, eta=ddim_eta, - # unconditional_guidance_scale=unconditional_guidance_scale, - # unconditional_conditioning=uc_full, - # ) - # x_samples_cfg = self.decode_first_stage(samples_cfg) - # log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg - - # return log - - @torch.no_grad() - def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): - use_ddim = ddim_steps is not None - - log = dict() - z,mask,masked_image_latents, c = self.get_input(batch, self.first_stage_key, bs=N, ) - c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N] - N = min(z.shape[0], N) - n_row = min(z.shape[0], n_row) - log["reconstruction"] = self.decode_first_stage(z) - log["control"] = c_cat * 2.0 - 1.0 - log["conditioning"] = log_txt_as_img((512, 512),batch[self.masked_image], batch[self.cond_stage_key], size=16) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = 
rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},mask=mask,masked_image_latents=masked_image_latents, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if unconditional_guidance_scale > 1.0: - uc_cross = self.get_unconditional_conditioning(N) - uc_cat = c_cat # torch.zeros_like(c_cat) - uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} - samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},mask=mask,masked_image_latents=masked_image_latents, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc_full, - ) - x_samples_cfg = self.decode_first_stage(samples_cfg) - log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg - - return log - @torch.no_grad() - def sample_log(self, cond,mask,masked_image_latents, batch_size, ddim, ddim_steps, **kwargs): - ddim_sampler = DDIMSampler(self) - b, c, h, w = cond["c_concat"][0].shape - shape = (self.channels, h // 8, w // 8) - samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond,mask=mask,masked_image_latents=masked_image_latents, verbose=False, **kwargs) - return samples, intermediates - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.control_model.parameters()) - # head_params=list() - # for name,param in self.control_model.named_parameters(): #self.model.named_parameters(): - # if "dcn" in name: - # # print(name) - # head_params.append(param) - # # params = list(self.control_model.parameters())+head_params - # params = head_params - if not self.sd_locked: - params += list(self.model.diffusion_model.output_blocks.parameters()) - params += list(self.model.diffusion_model.out.parameters()) - opt = torch.optim.AdamW(params, lr=lr) - return opt - - def low_vram_shift(self, is_diffusing): - if is_diffusing: - self.model = self.model.cuda() - self.control_model = self.control_model.cuda() - self.first_stage_model = self.first_stage_model.cpu() - self.cond_stage_model = self.cond_stage_model.cpu() - else: - self.model = self.model.cpu() - self.control_model = self.control_model.cpu() - self.first_stage_model = self.first_stage_model.cuda() - self.cond_stage_model = self.cond_stage_model.cuda() diff --git a/Control-Color/cldm/ddim_haced_sag_step.py b/Control-Color/cldm/ddim_haced_sag_step.py deleted file mode 100644 index 1bcca156320490717f7549e6aa4281d6cecc7116..0000000000000000000000000000000000000000 --- a/Control-Color/cldm/ddim_haced_sag_step.py +++ /dev/null @@ -1,494 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor -import torch.nn.functional as F - -import cv2 - -import einops -# Gaussian blur -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, 
steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -# processes and stores attention probabilities -class CrossAttnStoreProcessor: - def __init__(self): - self.attention_probs = None - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - self.attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(self.attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - model, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - masked_image_latents=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - sag_scale=0.75, - SAG_influence_step=600, - noise = None, - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - # print(shape) - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(model,conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask,masked_image_latents=masked_image_latents, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - sag_scale = sag_scale, - SAG_influence_step = SAG_influence_step, - noise = noise, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) - return samples, intermediates - - def add_noise(self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32) - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < 
len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - - return noisy_samples - - - def sag_masking(self, original_latents,model_output,x, attn_map, map_size, t, eps): - # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf - bh, hw1, hw2 = attn_map.shape - b, latent_channel, latent_h, latent_w = original_latents.shape - h = 4 #self.unet.config.attention_head_dim - if isinstance(h, list): - h = h[-1] - attn_map = attn_map.reshape(b, h, hw1, hw2) - attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 - attn_mask = ( - attn_mask.reshape(b, map_size[0], map_size[1]) - .unsqueeze(1) - .repeat(1, latent_channel, 1, 1) - .type(attn_map.dtype) - ) - attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) - degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) - degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) #x#original_latents - - return degraded_latents - - def pred_epsilon(self, sample, model_output, timestep): - alpha_prod_t = timestep - - beta_prod_t = 1 - alpha_prod_t - # print(self.model.parameterization)#eps - if self.model.parameterization == "eps": - pred_eps = model_output - elif self.model.parameterization == "sample": - pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) - elif self.model.parameterization == "v": - pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `eps`, `sample`," - " or `v`" - ) - - return pred_eps - - @torch.no_grad() - def ddim_sampling(self,model, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None,masked_image_latents=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1.,sag_scale = 0.75, SAG_influence_step=600, sag_enable = True, noise = None, unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - # timesteps =100 - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - # timesteps=timesteps[:-3] - # print("timesteps",timesteps) - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) - - for i, step in enumerate(iterator): - # print(step) - if step > SAG_influence_step: - sag_enable_t=True - else: - 
sag_enable_t=False - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim(img,mask,masked_image_latents, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - sag_scale = sag_scale, - sag_enable=sag_enable_t, - noise =noise, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - x_samples = model.decode_first_stage(img) - x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) - - #single image replace L channel - results_ori = [x_samples[i] for i in range(1)] - # results_ori=[i for i in results_ori] - - # cv2.imwrite("result_ori"+str(step)+".png",cv2.cvtColor(results_ori[0],cv2.COLOR_RGB2BGR)) - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x,mask,masked_image_latents, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1.,sag_scale = 0.75, sag_enable=True, noise=None, unconditional_conditioning=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - model_output = self.model.apply_model(x,mask,masked_image_latents, t, c) - else: - model_t = self.model.apply_model(x,mask,masked_image_latents, t, c) - model_uncond = self.model.apply_model(x,mask,masked_image_latents, t, unconditional_conditioning) - model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - - if self.model.parameterization == "v": - e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) - else: - e_t = model_output - - if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - 
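Note: the conditional/unconditional combination in p_sample_ddim above is standard classifier-free guidance. A one-line sketch for reference (names illustrative, not from this file):

def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    # guidance_scale == 1.0 recovers the purely conditional prediction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)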
- if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - if sag_enable == True: - uncond_attn, cond_attn = self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1.attention_probs.chunk(2) - # self-attention-based degrading of latents - map_size = self.model.model.diffusion_model.middle_block[1].map_size - degraded_latents = self.sag_masking( - pred_x0,model_output,x,uncond_attn, map_size, t, eps = noise, #self.pred_epsilon(x, model_uncond, self.model.alphas_cumprod[t]),#noise - ) - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - degraded_model_output = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c) - else: - degraded_model_t = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c) - degraded_model_uncond = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, unconditional_conditioning) - degraded_model_output = degraded_model_uncond + unconditional_guidance_scale * (degraded_model_t - degraded_model_uncond) - # print("sag_scale",sag_scale) - model_output += sag_scale * (model_output - degraded_model_output) - # model_output = (1-sag_scale) * model_output + sag_scale * degraded_model_output - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - num_reference_steps = timesteps.shape[0] - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) - noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred - x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 
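Note: the x0 prediction and the update at the end of p_sample_ddim implement the usual DDIM step; in the notation of the code, a_t and a_prev are cumulative alphas and sigma_t is the DDIM variance set by eta:

\[
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
          + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta
          + \sigma_t z,
\quad z \sim \mathcal{N}(0, I).
\]

With eta = 0 the sigma term vanishes and the update becomes deterministic.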
and i < num_steps - 1: - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: callback(i) - - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} - if return_intermediates: - out.update({'intermediates': intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) - - @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec diff --git a/Control-Color/cldm/ddim_hacked_sag.py b/Control-Color/cldm/ddim_hacked_sag.py deleted file mode 100644 index 50c866c14a29ad89c8b2c9d6396559c01d769d0b..0000000000000000000000000000000000000000 --- a/Control-Color/cldm/ddim_hacked_sag.py +++ /dev/null @@ -1,543 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor -import torch.nn.functional as F - -import cv2 -# Gaussian blur -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -# processes and stores attention probabilities -class CrossAttnStoreProcessor: - def __init__(self): - self.attention_probs = None - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = 
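Note: stochastic_encode above (like add_noise earlier in this file) is the closed-form forward diffusion step. A minimal standalone version, assuming alphas_cumprod is a 1-D tensor of cumulative alphas indexed by timestep:

import torch

def q_sample(x0: torch.Tensor, noise: torch.Tensor,
             alphas_cumprod: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1).to(x0)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise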
attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - self.attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(self.attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
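Note: CrossAttnStoreProcessor is written in the style of a diffusers attention processor whose only extra job is to stash the attention probabilities. A hypothetical wiring, following the commented-out lines later in p_sample_ddim (layer path and processor hand-off assumed; the released sampler instead reads attention_probs stored directly on attn1):

store_processor = CrossAttnStoreProcessor()
attn1 = model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1
attn1.processor = store_processor              # only effective if the attention layer consults .processor
# ... run one denoising step ...
attn_probs = store_processor.attention_probs   # (batch * heads, hw, hw)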
- ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - masked_image_latents=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - sag_scale=0.75, - SAG_influence_step=600, - noise = None, - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask,masked_image_latents=masked_image_latents, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - sag_scale = sag_scale, - SAG_influence_step = SAG_influence_step, - noise = noise, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) - return samples, intermediates - - def add_noise(self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32) - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - - return noisy_samples - # def add_noise( - # 
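Note: add_noise above rebuilds its own beta schedule with a plain torch.linspace, while LDM's default "linear" schedule squares a linspace of square roots, so the two ramps are close but not identical. A quick comparison (not part of the original code; assumes the standard make_beta_schedule behaviour):

import torch

n, b0, b1 = 1000, 0.00085, 0.0120
betas_plain  = torch.linspace(b0, b1, n)                      # as hard-coded in add_noise above
betas_scaled = torch.linspace(b0 ** 0.5, b1 ** 0.5, n) ** 2   # LDM "linear" (scaled-linear) schedule
print((betas_plain - betas_scaled).abs().max())               # small but non-zero difference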
self, - # original_samples: torch.FloatTensor, - # noise: torch.FloatTensor, - # timesteps: torch.FloatTensor, - # sigma_t, - # ) -> torch.FloatTensor: - - # # Make sure sigmas and timesteps have the same device and dtype as original_samples - - # sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - # if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # # mps does not support float64 - # schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - # timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - # else: - # schedule_timesteps = self.timesteps.to(original_samples.device) - # timesteps = timesteps.to(original_samples.device) - - # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - # sigma = sigmas[step_indices].flatten() - # while len(sigma.shape) < len(original_samples.shape): - # sigma = sigma.unsqueeze(-1) - # # print(sigma_t) - # noisy_samples = original_samples + noise * sigma_t - # return noisy_samples - - - def sag_masking(self, original_latents,model_output,x, attn_map, map_size, t, eps): - # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf - bh, hw1, hw2 = attn_map.shape - b, latent_channel, latent_h, latent_w = original_latents.shape - h = 4 #self.unet.config.attention_head_dim - if isinstance(h, list): - h = h[-1] - # print(attn_map.shape) - # print(original_latents.shape) - # print(map_size) - # Produce attention mask - attn_map = attn_map.reshape(b, h, hw1, hw2) - attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 - # print(attn_mask.shape) - attn_mask = ( - attn_mask.reshape(b, map_size[0], map_size[1]) - .unsqueeze(1) - .repeat(1, latent_channel, 1, 1) - .type(attn_map.dtype) - ) - attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) - # print(attn_mask.shape) - # cv2.imwrite("attn_mask.png",attn_mask) - # Blur according to the self-attention mask - degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) - # degraded_latents = self.add_noise(degraded_latents, noise=eps, timesteps=t)#,sigma_t=sigma_t) - degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) #x#original_latents - # degraded_latents = self.model.get_x_t_from_start_and_t(degraded_latents,t,model_output) - # print(original_latents.shape) - # print(eps.shape) - # Noise it again to match the noise level - # print("t",t) - # degraded_latents = self.add_noise(degraded_latents, noise=eps, timesteps=t)#,sigma_t=sigma_t) - - return degraded_latents - - def pred_epsilon(self, sample, model_output, timestep): - alpha_prod_t = timestep - - beta_prod_t = 1 - alpha_prod_t - # print(self.model.parameterization)#eps - if self.model.parameterization == "eps": - pred_eps = model_output - elif self.model.parameterization == "sample": - pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) - elif self.model.parameterization == "v": - pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `eps`, `sample`," - " or `v`" - ) - - return pred_eps - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None,masked_image_latents=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., 
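Note: pred_epsilon above converts the network output to a noise estimate according to the parameterization; with the cumulative alpha passed in as `timestep`, the three branches correspond to:

\[
\epsilon =
\begin{cases}
\text{model output} & \text{"eps"} \\[4pt]
\dfrac{x_t - \sqrt{\bar\alpha_t}\,\hat{x}_0}{\sqrt{1-\bar\alpha_t}} & \text{"sample"},\ \hat{x}_0 = \text{model output} \\[4pt]
\sqrt{1-\bar\alpha_t}\,x_t + \sqrt{\bar\alpha_t}\,v & \text{"v"},\ v = \text{model output}
\end{cases}
\]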
score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1.,sag_scale = 0.75, SAG_influence_step=600, sag_enable = True, noise = None, unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - # timesteps =100 - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - # timesteps=timesteps[:-3] - # print("timesteps",timesteps) - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) - - for i, step in enumerate(iterator): - print(step) - if step > SAG_influence_step: - sag_enable_t=True - else: - sag_enable_t=False - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - # if mask is not None: - # assert x0 is not None - # img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - # img = img_orig * mask + (1. - mask) * img - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim(img,mask,masked_image_latents, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - sag_scale = sag_scale, - sag_enable=sag_enable_t, - noise =noise, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x,mask,masked_image_latents, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1.,sag_scale = 0.75, sag_enable=True, noise=None, unconditional_conditioning=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - # map_size = None - # def get_map_size(module, input, output): - # nonlocal map_size - # map_size = output.shape[-2:] - - # store_processor = CrossAttnStoreProcessor() - # for name, param in self.model.model.diffusion_model.named_parameters(): - # print(name) - # self.model.control_model.middle_block[1].transformer_blocks[0].attn1.processor = store_processor - # print(self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1) - # self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1 = store_processor - - # with self.model.model.diffusion_model.middle_block[1].register_forward_hook(get_map_size): - if unconditional_conditioning is 
None or unconditional_guidance_scale == 1.: - model_output = self.model.apply_model(x,mask,masked_image_latents, t, c) - else: - model_t = self.model.apply_model(x,mask,masked_image_latents, t, c) - model_uncond = self.model.apply_model(x,mask,masked_image_latents, t, unconditional_conditioning) - model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - - if self.model.parameterization == "v": - e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) - else: - e_t = model_output - - if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - if sag_enable == True: - uncond_attn, cond_attn = self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1.attention_probs.chunk(2) - # self-attention-based degrading of latents - map_size = self.model.model.diffusion_model.middle_block[1].map_size - degraded_latents = self.sag_masking( - pred_x0,model_output,x,uncond_attn, map_size, t, eps = noise, #self.pred_epsilon(x, model_uncond, self.model.alphas_cumprod[t]),#noise - ) - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - degraded_model_output = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c) - else: - degraded_model_t = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c) - degraded_model_uncond = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, unconditional_conditioning) - degraded_model_output = degraded_model_uncond + unconditional_guidance_scale * (degraded_model_t - degraded_model_uncond) - # print("sag_scale",sag_scale) - model_output += sag_scale * (model_output - degraded_model_output) - # model_output = (1-sag_scale) * model_output + sag_scale * degraded_model_output - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - num_reference_steps = timesteps.shape[0] - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) - noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred - x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: callback(i) - - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} - if return_intermediates: - out.update({'intermediates': intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) - - @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, 
dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec diff --git a/Control-Color/cldm/hack.py b/Control-Color/cldm/hack.py deleted file mode 100644 index 454361e9d036cd1a6a79122c2fd16b489e4767b1..0000000000000000000000000000000000000000 --- a/Control-Color/cldm/hack.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import einops - -import ldm.modules.encoders.modules -import ldm.modules.attention - -from transformers import logging -from ldm.modules.attention import default - - -def disable_verbosity(): - logging.set_verbosity_error() - print('logging improved.') - return - - -def enable_sliced_attention(): - ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward - print('Enabled sliced_attention.') - return - - -def hack_everything(clip_skip=0): - disable_verbosity() - ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward - ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip - print('Enabled clip hacks.') - return - - -# Written by Lvmin -def _hacked_clip_forward(self, text): - PAD = self.tokenizer.pad_token_id - EOS = self.tokenizer.eos_token_id - BOS = self.tokenizer.bos_token_id - - def tokenize(t): - return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] - - def transformer_encode(t): - if self.clip_skip > 1: - rt = self.transformer(input_ids=t, output_hidden_states=True) - return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) - else: - return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state - - def split(x): - return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] - - def pad(x, p, i): - return x[:i] if len(x) >= i else x + [p] * (i - len(x)) - - raw_tokens_list = tokenize(text) - tokens_list = [] - - for raw_tokens in raw_tokens_list: - raw_tokens_123 = split(raw_tokens) - raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] - raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] - tokens_list.append(raw_tokens_123) - - tokens_list = torch.IntTensor(tokens_list).to(self.device) - - feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') - y = transformer_encode(feed) - z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3) - - return z - - -# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py -def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - limit = k.shape[0] - att_step = 1 - q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) - k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) - v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) - - q_chunks.reverse() - k_chunks.reverse() - v_chunks.reverse() - sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - del k, q, v - for i in range(0, limit, att_step): - q_buffer = q_chunks.pop() - k_buffer = k_chunks.pop() - v_buffer = v_chunks.pop() - sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale - - del k_buffer, 
q_buffer - # attention, what we cannot get enough of, by chunks - - sim_buffer = sim_buffer.softmax(dim=-1) - - sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) - del v_buffer - sim[i:i + att_step, :, :] = sim_buffer - - del sim_buffer - sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) - return self.to_out(sim) diff --git a/Control-Color/cldm/model.py b/Control-Color/cldm/model.py deleted file mode 100644 index fed3c31ac145b78907c7f771d1d8db6fb32d92ed..0000000000000000000000000000000000000000 --- a/Control-Color/cldm/model.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import torch - -from omegaconf import OmegaConf -from ldm.util import instantiate_from_config - - -def get_state_dict(d): - return d.get('state_dict', d) - - -def load_state_dict(ckpt_path, location='cpu'): - _, extension = os.path.splitext(ckpt_path) - if extension.lower() == ".safetensors": - import safetensors.torch - state_dict = safetensors.torch.load_file(ckpt_path, device=location) - else: - state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) - state_dict = get_state_dict(state_dict) - print(f'Loaded state_dict from [{ckpt_path}]') - return state_dict - - -def create_model(config_path): - config = OmegaConf.load(config_path) - model = instantiate_from_config(config.model).cpu() - print(f'Loaded model config from [{config_path}]') - return model diff --git a/Control-Color/config.py b/Control-Color/config.py deleted file mode 100644 index e0c738d8cbad66bbe1666284aef926c326849701..0000000000000000000000000000000000000000 --- a/Control-Color/config.py +++ /dev/null @@ -1 +0,0 @@ -save_memory = False diff --git a/Control-Color/ldm/__pycache__/util.cpython-38.pyc b/Control-Color/ldm/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 5562e0ca576e3e4ff95d6f5b8545f5f121d2885c..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/Control-Color/ldm/models/__pycache__/autoencoder.cpython-38.pyc deleted file mode 100644 index 08646db4cc4361a58ad55ac34fd8e13f825ac45e..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/__pycache__/autoencoder.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/__pycache__/autoencoder_train.cpython-38.pyc b/Control-Color/ldm/models/__pycache__/autoencoder_train.cpython-38.pyc deleted file mode 100644 index cc933d135fb1751dc5987600ca27081f90fbc98e..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/__pycache__/autoencoder_train.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/autoencoder.py b/Control-Color/ldm/models/autoencoder.py deleted file mode 100644 index 1a2031ca9ee7e389063e74bfc4aa7479f98027b6..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/autoencoder.py +++ /dev/null @@ -1,220 +0,0 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -from contextlib import contextmanager - -# from ldm.modules.diffusionmodules.model_window import Encoder, Decoder -from ldm.modules.diffusionmodules.model_brefore_dcn import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config -from ldm.modules.ema import LitEma - - -class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - 
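Note: hack.py and config.py above are consumed at start-up; a typical wiring, assuming a share.py that follows the common ControlNet pattern (share.py itself is not part of this diff):

import config
from cldm.hack import disable_verbosity, enable_sliced_attention

disable_verbosity()                # silence transformers logging
if config.save_memory:             # False in this repo's config.py
    enable_sliced_attention()      # patch CrossAttention.forward with the chunked version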
ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): - super().__init__() - self.learn_logvar = learn_logvar - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - - self.use_ema = ema_decay is not None - if self.use_ema: - self.ema_decay = ema_decay - assert 0. < ema_decay < 1. - self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss - - def validation_step(self, 
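Note: the AutoencoderKL above is used as a plain encode/decode round trip when it acts as the first stage. A minimal sketch (shapes assume the usual f=8, z_channels=4 configuration; `vae` is an already-initialised instance):

import torch

x = torch.randn(1, 3, 256, 256)      # image batch scaled to [-1, 1], NCHW
posterior = vae.encode(x)            # DiagonalGaussianDistribution over the latent
z = posterior.sample()               # (1, 4, 32, 32) latent
x_rec = vae.decode(z)                # reconstruction, same shape as x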
batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
- return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x - diff --git a/Control-Color/ldm/models/autoencoder_train.py b/Control-Color/ldm/models/autoencoder_train.py deleted file mode 100644 index ba158b44527c5e5c6dba868ad6b052fca59863d0..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/autoencoder_train.py +++ /dev/null @@ -1,299 +0,0 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -from contextlib import contextmanager - -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config -from ldm.modules.ema import LitEma - -import random -import cv2 - -# from cldm.model import create_model, load_state_dict -# model = create_model('./models/cldm_v15_inpainting.yaml') -# resume_path = "/data/2023text2edit/ControlNet/ckpt_inpainting_from5625+5625/epoch0_global-step3750.ckpt" -# model.load_state_dict(load_state_dict(resume_path, location='cpu'),strict=True) -# model.half() -# model.cuda() - -class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="input", - output_key="jpg", - gray_image_key="gray", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): - super().__init__() - self.learn_logvar = learn_logvar - self.image_key = image_key - self.gray_image_key = gray_image_key - self.output_key=output_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - - # model = create_model('./models/cldm_v15_inpainting.yaml') - # resume_path = "/data/2023text2edit/ControlNet/ckpt_inpainting_from5625+5625/epoch0_global-step3750.ckpt" - # model.load_state_dict(load_state_dict(resume_path, location='cpu'),strict=True) - # model.half() - # self.model=model.cuda() - # # self.model=model.eval() - # for param in self.model.parameters(): - # param.requires_grad = False - - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - - self.use_ema = ema_decay is not None - if self.use_ema: - self.ema_decay = ema_decay - assert 0. < ema_decay < 1. 
- self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z,gray_content_z): - z = self.post_quant_conv(z) - gray_content_z = self.post_quant_conv(gray_content_z) - dec = self.decoder(z,gray_content_z) - return dec - - def forward(self, input,gray_image, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - gray_posterior = self.encode(gray_image) - if sample_posterior: - gray_content_z = gray_posterior.sample() - else: - gray_content_z = gray_posterior.mode() - dec = self.decode(z,gray_content_z) - return dec, posterior - - def get_input(self, batch,k0, k1,k2): - # print(batch) - # print(k) - # x = batch[k] - # if len(x.shape) == 3: - # x = x[..., None] - # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - gray_image = batch[k2] - if len(gray_image.shape) == 3: - gray_image = gray_image[..., None] - gray_image = gray_image.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - - - # t=random.randint(1,100)#120 - # print(t) - # model=model.cuda() - x = batch[k0]#self.model.get_noised_images(((gt.squeeze(0)+1.0)/2.0).permute(2,0,1).to(memory_format=torch.contiguous_format).type(torch.HalfTensor).cuda(),t=torch.Tensor([t]).long().cuda()) - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - # x = x.float() - # torch.cuda.empty_cache() - # print(input.shape) - # cv2.imwrite("tttt.png",cv2.cvtColor(x.squeeze(0).permute(1,2,0).cpu().numpy()*255.0, cv2.COLOR_RGB2BGR)) - # x = x*2.0-1.0 - # x = x.squeeze(0).permute(1,2,0).cpu().numpy()*2.0-1.0 - # if len(x.shape) == 3: - # x = x[..., None] - # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) - gt = batch[k1] - if len(gt.shape) == 3: - gt = gt[..., None] - gt = gt.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - - return gt,x,gray_image - - def training_step(self, batch, batch_idx, optimizer_idx): - with torch.no_grad(): - outputs,inputs,gray_images = self.get_input(batch, self.image_key,self.output_key,self.gray_image_key) - reconstructions, posterior = self(inputs,gray_images) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(outputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, 
logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - # print(aeloss) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(outputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - # print(discloss) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - outputs,inputs,gray_images = self.get_input(batch, self.image_key,self.output_key,self.gray_image_key) - reconstructions, posterior = self(inputs,gray_images) - aeloss, log_dict_ae = self.loss(outputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(outputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - # ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - # self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) - # for name,param in self.decoder.named_parameters(): - # if "dcn" in name: - # print(name) - ae_params_list = list(self.decoder.dcn_in.parameters())+list(self.decoder.mid.block_1.dcn1.parameters())+list(self.decoder.mid.block_1.dcn2.parameters())+list(self.decoder.mid.block_2.dcn1.parameters())+list(self.decoder.mid.block_2.dcn2.parameters()) - # print(ae_params_list) - # for i in ae_params_list: - # print(i) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def get_gray_content_z(self,gray_image): - # if len(gray_image.shape) == 3: - # gray_image = gray_image[..., None] - gray_image = gray_image.unsqueeze(0).permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - gray_content_z=self.encode(gray_image) - gray_content_z = gray_content_z.sample() - return gray_content_z - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - gt,x,gray_image = self.get_input(batch, self.image_key,self.output_key,self.gray_image_key) - log['gt']=gt - x = x.to(self.device) - gray_image = gray_image.to(self.device) - if not only_inputs: - xrec, posterior = self(x,gray_image) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - gray_image = self.to_rgb(gray_image) - xrec = self.to_rgb(xrec) - gray_content_z=self.encode(gray_image) - gray_content_z = gray_content_z.sample() - log["samples"] = 
self.decode(torch.randn_like(posterior.sample()),gray_content_z) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x - diff --git a/Control-Color/ldm/models/diffusion/__init__.py b/Control-Color/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/Control-Color/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 8203a0ba47940607d37eda2eb402fa0146802605..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/Control-Color/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc deleted file mode 100644 index 9f6c961f543c590160df7fef1155f66a835ff80f..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/Control-Color/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc deleted file mode 100644 index 7cb130c1701f6dfc6a223e61fde399140c1aa81b..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc b/Control-Color/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc deleted file mode 100644 index 1cccbb62b2dbc9dfba49c3b981ffd583a6e15de8..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/models/diffusion/__pycache__/ddpm_nonoise.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/models/diffusion/ddim.py b/Control-Color/ldm/models/diffusion/ddim.py deleted file mode 100644 index 37a82117e53e9f9e2cc5b6831601608e18b1950d..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,337 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - 
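Note: the training-time autoencoder above (autoencoder_train.py) differs from the plain one in that its decoder takes a second latent carrying grayscale content. A hypothetical inference-time use (method names from the class above; `color_latent` and `gray_image` are illustrative inputs):

# vae_model: the content-guided AutoencoderKL defined above
# gray_image: (H, W, 3) grayscale image tensor scaled to [-1, 1]
gray_z = vae_model.get_gray_content_z(gray_image)    # encode the grayscale structure
rgb = vae_model.decode(color_latent, gray_z)         # decode the color latent guided by that structure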
self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - masked_image_latents=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask,masked_image_latents=masked_image_latents, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None,masked_image_latents=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - # if mask is not None: - # assert x0 is not None - # img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - # img = img_orig * mask + (1. 
- mask) * img - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim(img,mask,masked_image_latents, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x,mask,masked_image_latents, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - model_output = self.model.apply_model(x,mask,masked_image_latents, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] - else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) - elif isinstance(c, list): - c_in = list() - assert isinstance(unconditional_conditioning, list) - for i in range(len(c)): - c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) - else: - c_in = torch.cat([unconditional_conditioning, c]) - model_uncond, model_t = self.model.apply_model(x_in,mask,masked_image_latents, t_in, c_in).chunk(2) - model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - - if self.model.parameterization == "v": - e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) - else: - e_t = model_output - - if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = 
self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) - noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred - x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: callback(i) - - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} - if return_intermediates: - out.update({'intermediates': intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) - - @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): 
- index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file diff --git a/Control-Color/ldm/models/diffusion/ddpm.py b/Control-Color/ldm/models/diffusion/ddpm.py deleted file mode 100644 index 37c699c234e0844ea2ec9c80486d207531837d01..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/ddpm.py +++ /dev/null @@ -1,1911 +0,0 @@ -""" -wild mixture of -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers --- merci -""" - -import torch -import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat -from contextlib import contextmanager, nullcontext -from functools import partial -import itertools -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only -from omegaconf import ListConfig - -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler - - -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - -def prepare_mask_latents( - mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if 
mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - -class DDPM(pl.LightningModule): - # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - make_it_fit=False, - ucg_training=None, - reset_ema=False, - reset_num_ema_updates=False, - ): - super().__init__() - assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' - self.parameterization = parameterization - print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? 
- self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - - if monitor is not None: - self.monitor = monitor - self.make_it_fit = make_it_fit - if reset_ema: assert exists(ckpt_path) - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) - if reset_ema: - assert self.use_ema - print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") - self.model_ema = LitEma(self.model) - if reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") - assert self.use_ema - self.model_ema.reset_num_updates() - - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - else: - self.register_buffer('logvar', logvar) - - self.ucg_training = ucg_training or dict() - if self.ucg_training: - self.ucg_prng = np.random.RandomState() - - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) - - if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) - elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) - elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) - else: - raise NotImplementedError("mu not supported") - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - @torch.no_grad() - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - if self.make_it_fit: - n_params = len([name for name, _ in - itertools.chain(self.named_parameters(), - self.named_buffers())]) - for name, param in tqdm( - itertools.chain(self.named_parameters(), - self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params - ): - if not name in sd: - continue - old_shape = sd[name].shape - new_shape = param.shape - assert len(old_shape) == len(new_shape) - if len(new_shape) > 2: - # we only modify first two axes - assert new_shape[2:] == old_shape[2:] - # assumes first axis corresponds to output dim - if not new_shape == old_shape: - new_param = param.clone() - old_param = sd[name] - if len(new_shape) == 1: - for i in range(new_param.shape[0]): - new_param[i] = old_param[i % old_shape[0]] - elif len(new_shape) >= 2: - for i in range(new_param.shape[0]): - for j in range(new_param.shape[1]): - new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] - - n_used_old = torch.ones(old_shape[1]) - for j in range(new_param.shape[1]): - n_used_old[j % old_shape[1]] += 1 - n_used_new = torch.zeros(new_shape[1]) - for j in range(new_param.shape[1]): - n_used_new[j] = n_used_old[j % old_shape[1]] - - n_used_new = n_used_new[None, :] - while len(n_used_new.shape) < len(new_shape): - n_used_new = n_used_new.unsqueeze(-1) - new_param /= n_used_new - - sd[name] = new_param - - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} 
with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing Keys:\n {missing}") - if len(unexpected) > 0: - print(f"\nUnexpected Keys:\n {unexpected}") - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def predict_start_from_z_and_v(self, x_t, t, v): - # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) - - # def get_x_t_from_start_and_t(self, start, t, v): - # return ( - # (start+extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, start.shape) * v)/extract_into_tensor(self.sqrt_alphas_cumprod, t, start.shape) - # ) - - def predict_eps_from_z_and_v(self, x_t, t, v): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1., 1.) 
- - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) - - def get_v(self, x, noise, t): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def p_losses(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.model(x_noisy, t) - - loss_dict = {} - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start - elif self.parameterization == "v": - target = self.get_v(x_start, noise, t) - else: - raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") - - loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - - log_prefix = 'train' if self.training else 'val' - - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) - loss_simple = loss.mean() * self.l_simple_weight - - loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) - - loss = loss_simple + self.original_elbo_weight * loss_vlb - - loss_dict.update({f'{log_prefix}/loss': loss}) - - return loss, loss_dict - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t 
= torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = x.to(memory_format=torch.contiguous_format).float() - return x - - def shared_step(self, batch): - x = self.get_input(batch, self.first_stage_key) - loss, loss_dict = self(x) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - for k in self.ucg_training: - p = self.ucg_training[k]["p"] - val = self.ucg_training[k]["val"] - if val is None: - val = "" - for i in range(len(batch[k])): - if self.ucg_prng.choice(2, p=[1 - p, p]): - batch[k][i] = val - - loss, loss_dict = self.shared_step(batch) - - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) - - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) - - if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) - - return loss - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - _, loss_dict_no_ema = self.shared_step(batch) - with self.ema_scope(): - _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} - self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log["inputs"] = x - - # get diffusion row - diffusion_row = list() - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) - - log["samples"] = samples - log["denoise_row"] = self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - -class LatentDiffusion(DDPM): - """main class""" - - def __init__(self, - first_stage_config, - cond_stage_config, - contextual_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - 
cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - force_null_conditioning=False, - masked_image=None, - mask=None, - load_loss=False, - *args, **kwargs): - self.masked_image=masked_image - self.mask=mask - self.load_loss=load_loss - self.force_null_conditioning = force_null_conditioning - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: - conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - reset_ema = kwargs.pop("reset_ema", False) - reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) - ignore_keys = kwargs.pop("ignore_keys", []) - # print(conditioning_key) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - self.instantiate_contextual_stage(contextual_stage_config) - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - if reset_ema: - assert self.use_ema - print( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") - self.model_ema = LitEma(self.model) - if reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") - assert self.use_ema - self.model_ema.reset_num_updates() - - def make_cond_schedule(self, ): - self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) - ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids - - @rank_zero_only - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' - # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") - - def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_contextual_stage(self, config): - if self.load_loss==True: - model = instantiate_from_config(config) - model.load_state_dict(torch.load("/mnt/lustre/zxliang/zcli/data/vgg19_conv.pth"), strict=False) - print("vgg loaded") - self.contextual_stage_model = model.eval() - for param in self.contextual_stage_model.parameters(): - param.requires_grad = False - self.contextual_loss = ContextualLoss().to(self.device) - elif self.load_loss==False: - self.contextual_stage_model = None - self.contextual_loss = None - else: - print("ERROR!!!!!self.load_loss should be either True or False!!!") - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != '__is_first_stage__' - assert config != '__is_unconditional__' - model = instantiate_from_config(config) - self.cond_stage_model = model - - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) - n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") - return self.scale_factor * z - - def get_learned_conditioning(self, c): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.cond_stage_model(c) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, 
self.cond_stage_forward)(c) - return c - - def meshgrid(self, h, w): - y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) - x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - - arr = torch.cat([y, x], dim=-1) - return arr - - def delta_border(self, h, w): - """ - :param h: height - :param w: width - :return: normalized distance to image border, - wtith min distance = 0 at border and max dist = 0.5 at image center - """ - lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) - arr = self.meshgrid(h, w) / lower_right_corner - dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] - dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] - return edge_dist - - def get_weighting(self, h, w, Ly, Lx, device): - weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) - weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) - - if self.split_input_params["tie_braker"]: - L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) - - L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) - weighting = weighting * L_weighting - return weighting - - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code - """ - :param x: img of size (bs, c, h, w) - :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) - """ - bs, nc, h, w = x.shape - - # number of crops in image - Ly = (h - kernel_size[0]) // stride[0] + 1 - Lx = (w - kernel_size[1]) // stride[1] + 1 - - if uf == 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - - weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) - - elif uf > 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, - stride=(stride[0] * uf, stride[1] * uf)) - fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) - - elif df > 1 and uf == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, - stride=(stride[0] // df, stride[1] // df)) - fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap - weighting = 
weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) - - else: - raise NotImplementedError - - return fold, unfold, normalization, weighting - - @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None, return_x=False): - # print("batch",batch) - # print("k",k) - x = super().get_input(batch, k) - masked_image=batch[self.masked_image] - mask=batch[self.mask] - # print(mask.shape,masked_image.shape) - mask = torch.nn.functional.interpolate(mask, size=(mask.shape[2] // 8, mask.shape[3] // 8)) - # mask=torch.cat([mask] * 2) #if do_classifier_free_guidance else mask - mask = mask.to(device="cuda",dtype=x.dtype) - do_classifier_free_guidance=False - # mask, masked_image_latents = self.prepare_mask_latents( - # mask, - # masked_image, - # batch_size * num_images_per_prompt, - # mask.shape[0], - # mask.shape[1], - # mask.dtype, - # "cuda", - # torch.manual_seed(859311133),#generator - # do_classifier_free_guidance, - # ) - # print("x",x) - if bs is not None: - x = x[:bs] - x = x.to(self.device) - - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - - masked_image_latents = self.get_first_stage_encoding(self.encode_first_stage(masked_image)).detach() - - if self.model.conditioning_key is not None and not self.force_null_conditioning: - if cond_key is None: - cond_key = self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', "txt"]: - xc = batch[cond_key] - elif cond_key in ['class_label', 'cls']: - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - if bs is not None: - c = c[:bs] - - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} - out = [z,mask,masked_image_latents, c] - if return_first_stage_outputs: - xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_x: - out.extend([x]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - return self.first_stage_model.encode(x) - - @torch.no_grad() - def decode_first_stage_before_vae(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. 
/ self.scale_factor * z - return z - - def shared_step(self, batch, **kwargs): - x,mask,masked_image_latents, c = self.get_input(batch, self.first_stage_key) - loss = self(x,mask,masked_image_latents, c) - return loss - - def forward(self, x,mask,masked_image_latents, c, *args, **kwargs): - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - if self.model.conditioning_key is not None: - assert c is not None - if self.cond_stage_trainable: - c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option - tc = self.cond_ids[t].to(self.device) - c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) - return self.p_losses(x,mask,masked_image_latents, c, t, *args, **kwargs) - - def apply_model(self, x_noisy, t, cond, return_ids=False): - if isinstance(cond, dict): - # hybrid case, cond is expected to be a dict - pass - else: - if not isinstance(cond, list): - cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' - cond = {key: cond} - - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) - return mean_flat(kl_prior) / np.log(2.0) - - def p_losses(self, x_start,mask,masked_image_latents, cond, t, noise=None): #latent diffusion - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy,mask,masked_image_latents, t, cond) - # print("before loss: ", model_output.shape) - loss_dict = {} - prefix = 'train' if self.training else 'val' - - if self.parameterization == "x0": - target = x_start - elif self.parameterization == "eps": - target = noise - elif self.parameterization == "v": - target = self.get_v(x_start, noise, t) - else: - raise NotImplementedError() - - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - - logvar_t = self.logvar[t].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - def p_mean_variance(self, x, c, t, 
clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1., 1.) - if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - if return_codebook_ids: - return model_mean, posterior_variance, posterior_log_variance, logits - elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - - if return_codebook_ids: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) - if return_x0: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, - log_every_t=None): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = 
tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed( - range(0, timesteps)) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, - log_every_t=None): - - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. 
- mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None, **kwargs): - if shape is None: - shape = (batch_size, self.channels, self.image_size, self.image_size) - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) - - @torch.no_grad() - def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): - if ddim: - ddim_sampler = DDIMSampler(self) - shape = (self.channels, self.image_size, self.image_size) - samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, - shape, cond, verbose=False, **kwargs) - - else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True, **kwargs) - - return samples, intermediates - - @torch.no_grad() - def get_unconditional_conditioning(self, batch_size, null_label=None): - if null_label is not None: - xc = null_label - if isinstance(xc, ListConfig): - xc = list(xc) - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - if hasattr(xc, "to"): - xc = xc.to(self.device) - c = self.get_learned_conditioning(xc) - else: - if self.cond_stage_key in ["class_label", "cls"]: - xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) - return self.get_learned_conditioning(xc) - else: - raise NotImplementedError("todo") - if isinstance(c, list): # in case the encoder gives us a list - for i in range(len(c)): - c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) - else: - c = repeat(c, '1 ... 
-> b ...', b=batch_size).to(self.device) - return c - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): - ema_scope = self.ema_scope if use_ema_scope else nullcontext - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption", "txt"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) - log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', "cls"]: - try: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc - except KeyError: - # probably no "human_label" in batch - pass - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): - # also display when quantizing x0 while sampling - with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - quantize_denoised=True) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples - - if unconditional_guidance_scale > 1.0: - uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) - if self.model.conditioning_key == "crossattn-adm": - uc = {"c_crossattn": [uc], "c_adm": 
c["c_adm"]} - with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) - x_samples_cfg = self.decode_first_stage(samples_cfg) - log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg - - if inpaint: - # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. - mask = mask[:, None, ...] - with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask - - # outpaint - mask = 1. - mask - with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples - - if plot_progressive_rows: - with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) - prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") - log["progressive_row"] = prog_row - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] - return [opt], scheduler - return opt - - @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
- return x - - -class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': - if not self.sequential_cross_attn: - cc = torch.cat(c_crossattn, 1) - else: - cc = c_crossattn - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': - xc = torch.cat([x] + c_concat, dim=1) - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'hybrid-adm': - assert c_adm is not None - xc = torch.cat([x] + c_concat, dim=1) - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm) - elif self.conditioning_key == 'crossattn-adm': - assert c_adm is not None - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm) - elif self.conditioning_key == 'adm': - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) - else: - raise NotImplementedError() - - return out - - -class LatentUpscaleDiffusion(LatentDiffusion): - def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): - super().__init__(*args, **kwargs) - # assumes that neither the cond_stage nor the low_scale_model contain trainable params - assert not self.cond_stage_trainable - self.instantiate_low_stage(low_scale_config) - self.low_scale_key = low_scale_key - self.noise_level_key = noise_level_key - - def instantiate_low_stage(self, config): - model = instantiate_from_config(config) - self.low_scale_model = model.eval() - self.low_scale_model.train = disabled_train - for param in self.low_scale_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): - if not log_mode: - z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) - else: - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) - x_low = batch[self.low_scale_key][:bs] - x_low = rearrange(x_low, 'b h w c -> b c h w') - x_low = x_low.to(memory_format=torch.contiguous_format).float() - zx, noise_level = self.low_scale_model(x_low) - if self.noise_level_key is not None: - # get noise level from batch instead, e.g. 
when extracting a custom noise level for bsr - raise NotImplementedError('TODO') - - all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} - if log_mode: - # TODO: maybe disable if too expensive - x_low_rec = self.low_scale_model.decode(zx) - return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level - return z, all_conds - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, - unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, - **kwargs): - ema_scope = self.ema_scope if use_ema_scope else nullcontext - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, - log_mode=True) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - log["x_lr"] = x_low - log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption", "txt"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) - log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if unconditional_guidance_scale > 1.0: - uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) - # TODO explore better "unconditional" choices for the other keys - # maybe guide away from empty text label and highest noise level and maximally degraded zx? - uc = dict() - for k in c: - if k == "c_crossattn": - assert isinstance(c[k], list) and len(c[k]) == 1 - uc[k] = [uc_tmp] - elif k == "c_adm": # todo: only run with text-based guidance? 
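Once the conditional dict `c` and its unconditional counterpart `uc` are assembled as above, classifier-free guidance combines the two model predictions in a fixed way; a minimal sketch assuming a generic eps-prediction `model` (illustrative signature, not the repository's API):

import torch

def cfg_eps(model, x, t, cond, uncond, guidance_scale):
    # Classifier-free guidance: extrapolate the conditional prediction
    # away from the unconditional one by the guidance scale.
    eps_cond = model(x, t, cond)
    eps_uncond = model(x, t, uncond)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)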
- assert isinstance(c[k], torch.Tensor) - #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level - uc[k] = c[k] - elif isinstance(c[k], list): - uc[k] = [c[k][i] for i in range(len(c[k]))] - else: - uc[k] = c[k] - - with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) - x_samples_cfg = self.decode_first_stage(samples_cfg) - log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg - - if plot_progressive_rows: - with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) - prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") - log["progressive_row"] = prog_row - - return log - - -class LatentFinetuneDiffusion(LatentDiffusion): - """ - Basis for different finetunas, such as inpainting or depth2image - To disable finetuning mode, set finetune_keys to None - """ - - def __init__(self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight" - ), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, **kwargs - ): - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", list()) - super().__init__(*args, **kwargs) - self.finetune_keys = finetune_keys - self.concat_keys = concat_keys - self.keep_dims = keep_finetune_dims - self.c_concat_log_start = c_concat_log_start - self.c_concat_log_end = c_concat_log_end - if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' - if exists(ckpt_path): - self.init_from_ckpt(ckpt_path, ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - - # make it explicit, finetune by including extra input channels - if exists(self.finetune_keys) and k in self.finetune_keys: - new_entry = None - for name, param in self.named_parameters(): - if name in self.finetune_keys: - print( - f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") - new_entry = torch.zeros_like(param) # zero init - assert exists(new_entry), 'did not find matching parameter to modify' - new_entry[:, :self.keep_dims, ...] 
= sd[k] - sd[k] = new_entry - - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): - ema_scope = self.ema_scope if use_ema_scope else nullcontext - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) - c_cat, c = c["c_concat"][0], c["c_crossattn"][0] - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption", "txt"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) - log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if not (self.c_concat_log_start is None and self.c_concat_log_end is None): - log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if unconditional_guidance_scale > 1.0: - uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) - uc_cat = c_cat - uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} - with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": 
[c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc_full, - ) - x_samples_cfg = self.decode_first_stage(samples_cfg) - log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg - - return log - - -class LatentInpaintDiffusion(LatentFinetuneDiffusion): - """ - can either run as pure inpainting model (only concat mode) or with mixed conditionings, - e.g. mask as concat and text via cross-attn. - To disable finetuning mode, set finetune_keys to None - """ - - def __init__(self, - concat_keys=("mask", "masked_image"), - masked_image_key="masked_image", - *args, **kwargs - ): - super().__init__(concat_keys, *args, **kwargs) - self.masked_image_key = masked_image_key - assert self.masked_image_key in concat_keys - - @torch.no_grad() - def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): - # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) - - assert exists(self.concat_keys) - c_cat = list() - for ck in self.concat_keys: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() - if bs is not None: - cc = cc[:bs] - cc = cc.to(self.device) - bchw = z.shape - if ck != self.masked_image_key: - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} - if return_first_stage_outputs: - return z, all_conds, x, xrec, xc - return z, all_conds - - @torch.no_grad() - def log_images(self, *args, **kwargs): - log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) - log["masked_image"] = rearrange(args[0]["masked_image"], - 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() - return log - - -class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): - """ - condition on monocular depth estimation - """ - - def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): - super().__init__(concat_keys=concat_keys, *args, **kwargs) - self.depth_model = instantiate_from_config(depth_stage_config) - self.depth_stage_key = concat_keys[0] - - @torch.no_grad() - def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): - # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) - - assert exists(self.concat_keys) - assert len(self.concat_keys) == 1 - c_cat = list() - for ck in self.concat_keys: - cc = batch[ck] - if bs is not None: - cc = cc[:bs] - cc = cc.to(self.device) - cc = self.depth_model(cc) - cc = torch.nn.functional.interpolate( - cc, - size=z.shape[2:], - mode="bicubic", - align_corners=False, - ) - - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], - keepdim=True) - cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. 
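The concat conditioning assembled in `get_input` above boils down to resizing the mask to the latent resolution and first-stage-encoding the masked image before channel-wise concatenation; a simplified standalone sketch (the `encode` callable and names are illustrative):

import torch
import torch.nn.functional as F

def build_inpaint_concat(mask, masked_image, encode, latent_hw):
    # mask: (B, 1, H, W) binary; masked_image: (B, 3, H, W) in [-1, 1].
    mask_lat = F.interpolate(mask, size=latent_hw)    # nearest-neighbor resize by default
    masked_lat = encode(masked_image)                 # first-stage latents, (B, C, h, w)
    return torch.cat([mask_lat, masked_lat], dim=1)   # channel-wise concat conditioning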
- c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} - if return_first_stage_outputs: - return z, all_conds, x, xrec, xc - return z, all_conds - - @torch.no_grad() - def log_images(self, *args, **kwargs): - log = super().log_images(*args, **kwargs) - depth = self.depth_model(args[0][self.depth_stage_key]) - depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ - torch.amax(depth, dim=[1, 2, 3], keepdim=True) - log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. - return log - - -class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): - """ - condition on low-res image (and optionally on some spatial noise augmentation) - """ - def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, - low_scale_config=None, low_scale_key=None, *args, **kwargs): - super().__init__(concat_keys=concat_keys, *args, **kwargs) - self.reshuffle_patch_size = reshuffle_patch_size - self.low_scale_model = None - if low_scale_config is not None: - print("Initializing a low-scale model") - assert exists(low_scale_key) - self.instantiate_low_stage(low_scale_config) - self.low_scale_key = low_scale_key - - def instantiate_low_stage(self, config): - model = instantiate_from_config(config) - self.low_scale_model = model.eval() - self.low_scale_model.train = disabled_train - for param in self.low_scale_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): - # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) - - assert exists(self.concat_keys) - assert len(self.concat_keys) == 1 - # optionally make spatial noise_level here - c_cat = list() - noise_level = None - for ck in self.concat_keys: - cc = batch[ck] - cc = rearrange(cc, 'b h w c -> b c h w') - if exists(self.reshuffle_patch_size): - assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) - if bs is not None: - cc = cc[:bs] - cc = cc.to(self.device) - if exists(self.low_scale_model) and ck == self.low_scale_key: - cc, noise_level = self.low_scale_model(cc) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - if exists(noise_level): - all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} - else: - all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} - if return_first_stage_outputs: - return z, all_conds, x, xrec, xc - return z, all_conds - - @torch.no_grad() - def log_images(self, *args, **kwargs): - log = super().log_images(*args, **kwargs) - log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') - return log diff --git a/Control-Color/ldm/models/diffusion/dpm_solver/__init__.py b/Control-Color/ldm/models/diffusion/dpm_solver/__init__.py deleted file mode 100644 index 7427f38c07530afbab79154ea8aaf88c4bf70a08..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/dpm_solver/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/Control-Color/ldm/models/diffusion/dpm_solver/dpm_solver.py b/Control-Color/ldm/models/diffusion/dpm_solver/dpm_solver.py deleted file 
mode 100644 index 095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ /dev/null @@ -1,1154 +0,0 @@ -import torch -import torch.nn.functional as F -import math -from tqdm import tqdm - - -class NoiseScheduleVP: - def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., - ): - """Create a wrapper class for the forward SDE (VP type). - *** - Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. - We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. - *** - The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). - We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). - Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - t = self.inverse_lambda(lambda_t) - =============================================================== - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - 1. For discrete-time DPMs: - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - 2. For continuous-time DPMs: - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - =============================================================== - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). 
- - =============================================================== - Example: - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - """ - - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. 
* lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - We support four types of the diffusion model by setting `model_type`: - 1. "noise": noise prediction model. (Trained by predicting noise). - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). 
- We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - =============================================================== - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. 
- """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): - """Construct a DPM-Solver. - We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). - If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = model_fn - self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
- s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). - """ - if skip_type == 'logSNR': - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 
- t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] - else: - orders = [3, ] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2, ] * K - else: - K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] - elif order == 1: - K = 1 - orders = [1, ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.predict_x0: - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. 
If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). 
- If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1. / 3. - if r2 is None: - r2 = 2. / 3. - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) - alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. 
/ r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) - - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) - else: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. 
The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 - ) - else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 - ) - return x_t - - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
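A minimal usage sketch of this dispatcher, assuming an already-constructed `DPM_Solver` instance named `solver` (the tensors below are placeholders, not from the deleted file):

```python
import torch

x = torch.randn(4, 4, 64, 64)   # latent at the starting time (placeholder)
s = torch.full((4,), 0.8)       # starting time, shape (batch,)
t = torch.full((4,), 0.6)       # ending time, shape (batch,)

# order=1 falls back to DDIM; order=2 uses r1 (default 0.5);
# order=3 uses r1 and r2 (defaults 1/3 and 2/3).
x_t = solver.singlestep_dpm_solver_update(x, s, t, order=2, solver_type='dpm_solver')
```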
- """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) - elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): - """ - The adaptive step size solver based on singlestep DPM-Solver. - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. 
- """ - ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) - elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) - nfe += order - print('adaptive solver nfe', nfe) - return x - - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - ===================================================== - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. 
- Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). - We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. - ===================================================== - Some advices for choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). 
- This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID - for diffusion models sampling by diffusion SDEs for low-resolutional images - (such as CIFAR-10). However, we observed that such trick does not matter for - high-resolutional images. As it needs an additional NFE, we do not recommend - it for high-resolutional images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. - Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in tqdm(range(1, order), desc="DPM init order"): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in tqdm(range(order, steps + 1), desc="DPM multistep"): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final and steps < 15: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. 
- if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order, ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
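A minimal usage sketch of this helper (the tensors are placeholders): a per-sample coefficient of shape [N] is padded with trailing singleton axes so it broadcasts against an image batch.

```python
import torch

def expand_dims(v, dims):
    # same expression as the helper below: append (dims - 1) singleton axes
    return v[(...,) + (None,) * (dims - 1)]

coef = torch.tensor([0.1, 0.2])            # shape [N]
x = torch.randn(2, 3, 8, 8)                # shape [N, C, H, W]
scaled = expand_dims(coef, x.dim()) * x    # coef broadcast as [N, 1, 1, 1]
```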
- """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/Control-Color/ldm/models/diffusion/dpm_solver/sampler.py b/Control-Color/ldm/models/diffusion/dpm_solver/sampler.py deleted file mode 100644 index 7d137b8cf36718c1c58faa09f9dd919e5fb2977b..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/dpm_solver/sampler.py +++ /dev/null @@ -1,87 +0,0 @@ -"""SAMPLING ONLY.""" -import torch - -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} - - -class DPMSolverSampler(object): - def __init__(self, model, **kwargs): - super().__init__() - self.model = model - to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') - - device = self.model.betas.device - if x_T is None: - img = torch.randn(size, device=device) - else: - img = x_T - - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - ns, - model_type=MODEL_TYPES[self.model.parameterization], - guidance_type="classifier-free", - condition=conditioning, - unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) - - return x.to(device), None \ No newline at end of file diff --git a/Control-Color/ldm/models/diffusion/plms.py b/Control-Color/ldm/models/diffusion/plms.py deleted file mode 100644 index 7002a365d27168ced0a04e9a4d83e088f8284eae..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/plms.py +++ /dev/null @@ -1,244 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = 
model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/Control-Color/ldm/models/diffusion/sampling_util.py b/Control-Color/ldm/models/diffusion/sampling_util.py deleted file mode 100644 index 7eff02be6d7c54d43ee6680636ac0698dd3b3f33..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/diffusion/sampling_util.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import numpy as np - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions. - From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - -def norm_thresholding(x0, value): - s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) - return x0 * (value / s) - - -def spatial_norm_thresholding(x0, value): - # b c h w - s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file diff --git a/Control-Color/ldm/models/logger.py b/Control-Color/ldm/models/logger.py deleted file mode 100644 index a266e77eba5555b077fb2b2f59d125bf1a52b2c6..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/models/logger.py +++ /dev/null @@ -1,93 +0,0 @@ -import os - -import numpy as np -import torch -import torchvision -from PIL import Image -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities.distributed import rank_zero_only - -# import pdb - -class ImageLogger(Callback): - def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, - rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, - log_images_kwargs=None,ckpt_dir="./ckpt"): - super().__init__() - self.rescale = rescale - self.batch_freq = batch_frequency - self.max_images = max_images - if not increase_log_steps: - self.log_steps = [self.batch_freq] - self.clamp = clamp - self.disabled = disabled - self.log_on_batch_idx = log_on_batch_idx - self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} - self.log_first_step = log_first_step - self.ckpt_dir=ckpt_dir - self.global_save_num=-2000 - self.global_save_num1=-100 - - @rank_zero_only - def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): - root = os.path.join(save_dir, "image_log", split) - # print(images) - for k in images: - grid = 
torchvision.utils.make_grid(images[k], nrow=4) - if self.rescale: - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w - grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) - grid = grid.numpy() - grid = (grid * 255).astype(np.uint8) - filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) - path = os.path.join(root, filename) - os.makedirs(os.path.split(path)[0], exist_ok=True) - Image.fromarray(grid).save(path) - - def log_img(self, pl_module, batch, batch_idx, split="train"): - check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and - callable(pl_module.log_images) and - self.max_images > 0): - logger = type(pl_module.logger) - - is_train = pl_module.training - if is_train: - pl_module.eval() - - with torch.no_grad(): - images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) - - for k in images: - N = min(images[k].shape[0], self.max_images) - images[k] = images[k][:N] - if isinstance(images[k], torch.Tensor): - images[k] = images[k].detach().cpu() - if self.clamp: - images[k] = torch.clamp(images[k], -1., 1.) - - self.log_local(pl_module.logger.save_dir, split, images, - pl_module.global_step, pl_module.current_epoch, batch_idx) - - if is_train: - pl_module.train() - - def check_frequency(self, check_idx): - return check_idx % self.batch_freq == 0 - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - #if not self.disabled: - #if pl_module.global_step%50 == 0: - # if pl_module.current_epoch-self.global_save_num1 > 0: - # print(batch_idx) - if batch_idx % 500 == 0: - # print("inside") - # pdb.set_trace() - # self.global_save_num1=pl_module.current_epoch - self.log_img(pl_module, batch, batch_idx, split="train_"+"ckpt_inpainting_from5625_2+3750_exemplar_only_vae") - #if pl_module.global_step%1200 == 0 and self.check_frequency(batch_idx): - if batch_idx % 1000 == 0: - # if pl_module.current_epoch-self.global_save_num>10 and self.check_frequency(batch_idx): - # self.global_save_num=pl_module.current_epoch - trainer.save_checkpoint(self.ckpt_dir+"/epoch"+str(pl_module.current_epoch)+"_global-step"+str(pl_module.global_step)+".ckpt") diff --git a/Control-Color/ldm/modules/__pycache__/attention.cpython-38.pyc b/Control-Color/ldm/modules/__pycache__/attention.cpython-38.pyc deleted file mode 100644 index cb9b54a6384c2241078fcbbbd72754c0e68c651e..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/__pycache__/attention.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/__pycache__/attention_dcn_control.cpython-38.pyc b/Control-Color/ldm/modules/__pycache__/attention_dcn_control.cpython-38.pyc deleted file mode 100644 index a60be361286b5f1dced320a6fdd0a36ea5488baf..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/__pycache__/attention_dcn_control.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/__pycache__/ema.cpython-38.pyc b/Control-Color/ldm/modules/__pycache__/ema.cpython-38.pyc deleted file mode 100644 index ea3c5073b265bb8c3a8afeb75403c5609909f6c5..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/__pycache__/ema.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/attention.py b/Control-Color/ldm/modules/attention.py deleted file mode 100644 index 
e274b9020a1713077b3399767d5c156966d75764..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/attention.py +++ /dev/null @@ -1,653 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat -from typing import Optional, Any - -from ldm.modules.diffusionmodules.util import checkpoint - - -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - -# CrossAttn precision handling -import os -_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") - -def exists(val): - return val is not None - - -def uniq(arr): - return{el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. 
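A minimal usage sketch (the 320-channel conv is a placeholder): zero-initialising a layer makes a newly attached branch start as a no-op; the `SpatialTransformer` further below uses this same helper for its `proj_out` projection.

```python
import torch.nn as nn

def zero_module(module):
    # same behaviour as the helper below: zero every parameter in place
    for p in module.parameters():
        p.detach().zero_()
    return module

proj_out = zero_module(nn.Conv2d(320, 320, kernel_size=1))  # weights and bias start at 0
```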
- """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) - - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ - - -class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head ** -0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) - self.attention_probs=None - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - else: - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - self.attention_probs = sim - #print("similarity",sim.shape) - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None - self.attention_probs=None - - def forward(self, x, context=None, mask=None):#,timestep=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - prob=rearrange(out, 'b n (h d) -> (b h) n d', h=h) - prob = einsum('b i d, b j d -> b i j', prob, v) - self.attention_probs = prob - - # print("emb",emb) - # print(timestep) - # if prob.shape[1] ==6144 and prob.shape[2]==6144 and timestep!=None and timestep<100: #and emb==0: - # torch.save(q,"./q1.pt") - # torch.save(k,"./k1.pt") - # torch.save(prob,"./prob.pt") - # print(prob.shape) - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention - } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): - super().__init__() - attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" - assert attn_mode in self.ATTENTION_MODES - attn_cls = self.ATTENTION_MODES[attn_mode] - self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None):#, timestep=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) - - def _forward(self, x, context=None):#, timestep=None): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. 
+ math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - - NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are - applied while sampling the normal with mean/std applied, therefore a, b args - should be adjusted to match the range of mean, std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - with torch.no_grad(): - return _trunc_normal_(tensor, mean, std, a, b) - -class PostionalAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., - proj_drop=0., attn_head_dim=None, use_rpb=False, window_size=14): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * self.num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) - else: - self.q_bias = None - self.v_bias = None - - # relative positional bias option - self.use_rpb = use_rpb - if use_rpb: - self.window_size = window_size - self.rpb_table = nn.Parameter(torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) - trunc_normal_(self.rpb_table, std=.02) - - coords_h = torch.arange(window_size) - coords_w = torch.arange(window_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w - coords_flatten = torch.flatten(coords, 1) # 2, h*w - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, h*w, h*w - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # h*w, h*w, 2 - relative_coords[:, :, 0] += window_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size - 1 - relative_coords[:, :, 0] *= 2 * window_size - 1 - relative_position_index = relative_coords.sum(-1) # h*w, h*w - self.register_buffer("relative_position_index", 
relative_position_index) - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(all_head_dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x): - B, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - if self.use_rpb: - relative_position_bias = self.rpb_table[self.relative_position_index.view(-1)].view( - self.window_size * self.window_size, self.window_size * self.window_size, -1) # h*w,h*w,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, h*w, h*w - attn += relative_position_bias - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - # x = self.drop(x) - # commit this for the orignal BERT implement - x = self.fc2(x) - x = self.drop(x) - return x - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, - attn_head_dim=None, use_rpb=False, window_size=14): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = PostionalAttention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim, - use_rpb=use_rpb, window_size=window_size) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if init_values > 0: - self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - else: - self.gamma_1, self.gamma_2 = None, None - - def forward(self, x): - if self.gamma_1 is None: - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - else: - x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) - return x - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, mask_cent=False): - super().__init__() - # to_2tuple = _ntuple(2) - # img_size = to_2tuple(img_size) - # patch_size = to_2tuple(patch_size) - img_size = tuple((img_size, img_size)) - patch_size = tuple((patch_size,patch_size)) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - self.mask_cent = mask_cent - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - # # From PyTorch internals - # def _ntuple(n): - # def parse(x): - # if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - # return tuple(x) - # return tuple(repeat(x, n)) - # return parse - - def forward(self, x, **kwargs): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - if self.mask_cent: - x[:, -1] = x[:, -1] - 0.5 - x = self.proj(x).flatten(2).transpose(1, 2) - return x - -class CnnHead(nn.Module): - def __init__(self, embed_dim, num_classes, window_size): - super().__init__() - self.embed_dim = embed_dim - self.num_classes = num_classes - self.window_size = window_size - - self.head = nn.Conv2d(embed_dim, num_classes, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - def forward(self, x): - x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - x = self.head(x) - x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - return x - -# sin-cos position encoding -# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 - -import numpy as np -def get_sinusoid_encoding_table(n_position, d_hid): - ''' Sinusoid position encoding table ''' - # TODO: make it with torch instead of numpy - def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - return torch.FloatTensor(sinusoid_table).unsqueeze(0) - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. 
- Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): - super().__init__() - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] - ) - if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) - else: - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.use_linear = use_linear - self.map_size = None - # self.cnnhead = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - # embed_dim=192 - # img_size=64 - # patch_size=8 - # self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, - # in_chans=4, embed_dim=embed_dim, mask_cent=False) - # num_patches = self.patch_embed.num_patches # 2 - - # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) - - # self.cnnhead = CnnHead(embed_dim, num_classes=32, window_size=img_size // patch_size) - - # self.posatnn_block = Block(dim=embed_dim, num_heads=3, mlp_ratio=4., qkv_bias=True, qk_scale=None, - # drop=0., attn_drop=0., norm_layer=nn.LayerNorm, - # init_values=0., use_rpb=True, window_size=img_size // patch_size) - # # self.window_size=8 - # self.norm1=nn.LayerNorm(embed_dim) - - def forward(self, x, context=None):#,timestep=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i])#,timestep=timestep) - if self.use_linear: - x = self.proj_out(x) - - # x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - # x = self.cnnhead(x) - # x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - - # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - # print("before",x.shape) - - # if x.shape[1]==4: - # x = self.patch_embed(x) - # print("after PatchEmbed",x.shape) - # x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() - - # x =self.posatnn_block(x) - # x = self.norm1(x) - # print("after norm",x.shape) - - # x = self.cnnhead(x) - - # print("after",x.shape) - if not self.use_linear: - x = self.proj_out(x) - - - self.map_size = x.shape[-2:] - return x + x_in - - # res = self.cnnhead(x+x_in) - # return res - diff --git a/Control-Color/ldm/modules/attention_dcn_control.py b/Control-Color/ldm/modules/attention_dcn_control.py deleted file mode 100644 index 39d49b77f8a364080dcf27680f3a7cac39bcac52..0000000000000000000000000000000000000000 --- 
a/Control-Color/ldm/modules/attention_dcn_control.py +++ /dev/null @@ -1,854 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat -from typing import Optional, Any - -from ldm.modules.diffusionmodules.util import checkpoint - -import torchvision -from torch.nn.modules.utils import _pair, _single - -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - -# CrossAttn precision handling -import os -_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") - -def exists(val): - return val is not None - - -def uniq(arr): - return{el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. 
- """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) - - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ - - -class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head ** -0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - else: - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None - - def forward(self, x, context=None, mask=None): - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention - } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): - super().__init__() - attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" - assert attn_mode in self.ATTENTION_MODES - attn_cls = self.ATTENTION_MODES[attn_mode] - self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) - - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. 
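        # (Editor's note, not part of the original file) A brief summary of the sampling
        # trick used below: draw U ~ Uniform(F(a'), F(b')) with a' = (a - mean)/std,
        # b' = (b - mean)/std and F the standard normal CDF, then map U back through the
        # inverse CDF, F^{-1}(u) = sqrt(2) * erfinv(2u - 1), and finally rescale by std
        # and shift by mean. The clamp at the very end only guards against floating-point
        # round-off at the interval boundaries.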
- # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - - NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are - applied while sampling the normal with mean/std applied, therefore a, b args - should be adjusted to match the range of mean, std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - with torch.no_grad(): - return _trunc_normal_(tensor, mean, std, a, b) - -class PostionalAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., - proj_drop=0., attn_head_dim=None, use_rpb=False, window_size=14): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * self.num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) - else: - self.q_bias = None - self.v_bias = None - - # relative positional bias option - self.use_rpb = use_rpb - if use_rpb: - self.window_size = window_size - self.rpb_table = nn.Parameter(torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) - trunc_normal_(self.rpb_table, std=.02) - - coords_h = torch.arange(window_size) - coords_w = torch.arange(window_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w - coords_flatten = torch.flatten(coords, 1) # 2, h*w - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, h*w, h*w - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # h*w, h*w, 2 - relative_coords[:, :, 0] += window_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size - 1 - relative_coords[:, :, 0] *= 2 * window_size - 1 - relative_position_index = relative_coords.sum(-1) # h*w, h*w - self.register_buffer("relative_position_index", relative_position_index) - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(all_head_dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x): - B, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - # qkv = self.qkv(x).reshape(B, 
N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - if self.use_rpb: - relative_position_bias = self.rpb_table[self.relative_position_index.view(-1)].view( - self.window_size * self.window_size, self.window_size * self.window_size, -1) # h*w,h*w,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, h*w, h*w - attn += relative_position_bias - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - # x = self.drop(x) - # commit this for the orignal BERT implement - x = self.fc2(x) - x = self.drop(x) - return x - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, - attn_head_dim=None, use_rpb=False, window_size=14): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = PostionalAttention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim, - use_rpb=use_rpb, window_size=window_size) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if init_values > 0: - self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) - else: - self.gamma_1, self.gamma_2 = None, None - - def forward(self, x): - if self.gamma_1 is None: - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - else: - x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) - return x - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, mask_cent=False): - super().__init__() - # to_2tuple = _ntuple(2) - # img_size = to_2tuple(img_size) - # patch_size = to_2tuple(patch_size) - img_size = tuple((img_size, img_size)) - patch_size = tuple((patch_size,patch_size)) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - self.mask_cent = mask_cent - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - # # From PyTorch internals - # def _ntuple(n): - # def parse(x): - # if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - # return tuple(x) - # return tuple(repeat(x, n)) - # return parse - - def forward(self, x, **kwargs): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
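        # (Editor's note, not part of the original file) When mask_cent is set, the last
        # input channel (presumably the mask channel) is shifted from [0, 1] to
        # [-0.5, 0.5] before projection. The strided Conv2d proj then embeds each
        # non-overlapping patch_size x patch_size patch, and flatten(2).transpose(1, 2)
        # yields tokens of shape (B, num_patches, embed_dim), e.g. (B, 196, 768) for the
        # default img_size=224, patch_size=16.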
- if self.mask_cent: - x[:, -1] = x[:, -1] - 0.5 - x = self.proj(x).flatten(2).transpose(1, 2) - return x - -class CnnHead(nn.Module): - def __init__(self, embed_dim, num_classes, window_size): - super().__init__() - self.embed_dim = embed_dim - self.num_classes = num_classes - self.window_size = window_size - - self.head = nn.Conv2d(embed_dim, num_classes, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - def forward(self, x): - x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - x = self.head(x) - x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - return x - -# sin-cos position encoding -# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 - -import numpy as np -def get_sinusoid_encoding_table(n_position, d_hid): - ''' Sinusoid position encoding table ''' - # TODO: make it with torch instead of numpy - def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - return torch.FloatTensor(sinusoid_table).unsqueeze(0) - -class ModulatedDeformConv(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - deformable_groups=1, - bias=True): - super(ModulatedDeformConv, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.deformable_groups = deformable_groups - self.with_bias = bias - # enable compatibility with nn.Conv2d - self.transposed = False - self.output_padding = _single(0) - - self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.init_weights() - - def init_weights(self): - n = self.in_channels - for k in self.kernel_size: - n *= k - stdv = 1. / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - -class ModulatedDeformConvPack(ModulatedDeformConv): - """ - https://github.com/xinntao/EDVR/blob/master/basicsr/models/ops/dcn/deform_conv.py - A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int or tuple[int]): Same as nn.Conv2d. - padding (int or tuple[int]): Same as nn.Conv2d. - dilation (int or tuple[int]): Same as nn.Conv2d. - groups (int): Same as nn.Conv2d. - bias (bool or str): If specified as `auto`, it will be decided by the - norm_cfg. Bias will be set as True if norm_cfg is None, otherwise - False. 
- """ - - _version = 2 - - def __init__(self, *args, **kwargs): - super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) - - self.conv_offset = nn.Conv2d( - self.in_channels,#self.in_channels+4, - self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, - stride=_pair(self.stride), - padding=_pair(self.padding), - dilation=_pair(self.dilation), - bias=True) - self.init_weights() - - def init_weights(self): - super(ModulatedDeformConvPack, self).init_weights() - if hasattr(self, 'conv_offset'): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x): - # out = self.conv_offset(torch.cat((x,gray_content),dim=1)) - out = self.conv_offset(x) - o1, o2, mask = torch.chunk(out, 3, dim=1) - offset = torch.cat((o1, o2), dim=1) - mask = torch.sigmoid(mask) - - # return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, - # self.groups, self.deformable_groups) - return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, - self.dilation, mask) - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): - super().__init__() - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] - ) - if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) - else: - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.use_linear = use_linear - # self.dcn_cnn = ModulatedDeformConvPack(inner_dim, - # inner_dim, - # kernel_size=3, - # stride=1, - # padding=1) - - # self.cnnhead = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - # embed_dim=192 - # img_size=64 - # patch_size=8 - # self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, - # in_chans=4, embed_dim=embed_dim, mask_cent=False) - # num_patches = self.patch_embed.num_patches # 2 - - # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) - - # self.cnnhead = CnnHead(embed_dim, num_classes=32, window_size=img_size // patch_size) - - # self.posatnn_block = Block(dim=embed_dim, num_heads=3, mlp_ratio=4., qkv_bias=True, qk_scale=None, - # drop=0., attn_drop=0., norm_layer=nn.LayerNorm, - # init_values=0., use_rpb=True, window_size=img_size // patch_size) - # # self.window_size=8 - # self.norm1=nn.LayerNorm(embed_dim) - - def forward(self, x, context=None,dcn_guide=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - 
context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - - # x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - # x = self.cnnhead(x) - # x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - - # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - # print("before",x.shape) - - # if x.shape[1]==4: - # x = self.patch_embed(x) - # print("after PatchEmbed",x.shape) - # x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() - - # x =self.posatnn_block(x) - # x = self.norm1(x) - # print("after norm",x.shape) - - # x = self.cnnhead(x) - - # x = self.dcn_cnn(x,dcn_guide) ## - - # print("after",x.shape) - if not self.use_linear: - x = self.proj_out(x) - - - - return x + x_in - - # res = self.cnnhead(x+x_in) - # return res - - -class SpatialTransformer_dcn(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): - super().__init__() - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] - ) - if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) - else: - self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.use_linear = use_linear - # print(in_channels,inner_dim) - self.dcn_cnn = ModulatedDeformConvPack(inner_dim, - inner_dim, - kernel_size=3, - stride=1, - padding=1) - - # self.cnnhead = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - # embed_dim=192 - # img_size=64 - # patch_size=8 - # self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, - # in_chans=4, embed_dim=embed_dim, mask_cent=False) - # num_patches = self.patch_embed.num_patches # 2 - - # self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) - - # self.cnnhead = CnnHead(embed_dim, num_classes=32, window_size=img_size // patch_size) - - # self.posatnn_block = Block(dim=embed_dim, num_heads=3, mlp_ratio=4., qkv_bias=True, qk_scale=None, - # drop=0., attn_drop=0., norm_layer=nn.LayerNorm, - # init_values=0., use_rpb=True, window_size=img_size // patch_size) - # # self.window_size=8 - # self.norm1=nn.LayerNorm(embed_dim) - - def forward(self, x, context=None,dcn_guide=None): - # note: if no context is given, cross-attention defaults to self-attention - if 
not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - - # x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - # x = self.cnnhead(x) - # x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - - # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() - # print("before",x.shape) - - # if x.shape[1]==4: - # x = self.patch_embed(x) - # print("after PatchEmbed",x.shape) - # x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() - - # x =self.posatnn_block(x) - # x = self.norm1(x) - # print("after norm",x.shape) - - # x = self.cnnhead(x) - x = self.dcn_cnn(x) - # print("after",x.shape) - if not self.use_linear: - x = self.proj_out(x) - - - - return x + x_in - - # res = self.cnnhead(x+x_in) - # return res diff --git a/Control-Color/ldm/modules/diffusionmodules/__init__.py b/Control-Color/ldm/modules/diffusionmodules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/Control-Color/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index d596868232bd30ff0e505e8b390adc54d7401150..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/Control-Color/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc deleted file mode 100644 index 03552b00a8ab6e59ef865475b7832cfedfee6045..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc b/Control-Color/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc deleted file mode 100644 index fade769ab1118412c73590e103ac2f59efbaaf34..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/diffusionmodules/__pycache__/model_brefore_dcn.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/Control-Color/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc deleted file mode 100644 index 08a266bed7078f12acb92791f41a0153643cec75..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/Control-Color/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 0e42a3c652654cd4e3be1e760aa827e77b4b9a00..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/diffusionmodules/model.py b/Control-Color/ldm/modules/diffusionmodules/model.py deleted file 
mode 100644 index e2744f63ae24db1570abfec1a7029133cdf1f105..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/diffusionmodules/model.py +++ /dev/null @@ -1,1107 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import torchvision -from torch.nn.modules.utils import _pair, _single -import numpy as np -from einops import rearrange -from typing import Optional, Any - -from ldm.modules.attention import MemoryEfficientCrossAttention - -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - print("No module 'xformers'. Proceeding without it.") - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) - return emb - - -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - 
padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - -class ResnetBlock_dcn(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - self.dcn1 = ModulatedDeformConvPack(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - self.dcn2 = ModulatedDeformConvPack(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x,grayx, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - h = self.dcn1(h,grayx)+h - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - h = self.dcn2(h,grayx)+h - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) - - h_ = self.proj_out(h_) - - return x+h_ - -class 
MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) - out = self.proj_out(out) - return x+out - - -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): - b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) - return x + out - - -def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": - attn_type = "vanilla-xformers" - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise NotImplementedError() - - -class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - 
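        # (Editor's note, not part of the original file) The code that follows builds the
        # down path of this U-Net-style backbone: conv_in lifts the input to ch channels,
        # then each of the len(ch_mult) resolution levels stacks num_res_blocks
        # ResnetBlocks (widths ch * ch_mult[i]), appends an attention block whenever the
        # current spatial resolution appears in attn_resolutions, and halves the
        # resolution with a Downsample module between levels (except after the last one).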
self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return 
self.conv_out.weight - - -class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - -class ModulatedDeformConv(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - deformable_groups=1, - bias=True): - super(ModulatedDeformConv, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.deformable_groups = deformable_groups - self.with_bias = bias - # enable compatibility with nn.Conv2d - self.transposed = False - self.output_padding = _single(0) - - self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.init_weights() - - def init_weights(self): - n = self.in_channels - for k in self.kernel_size: 
- n *= k - stdv = 1. / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - - # def forward(self, x, offset, mask): - # return torchvision.ops.con(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, - # self.groups, self.deformable_groups) - - -class ModulatedDeformConvPack(ModulatedDeformConv): - """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int or tuple[int]): Same as nn.Conv2d. - padding (int or tuple[int]): Same as nn.Conv2d. - dilation (int or tuple[int]): Same as nn.Conv2d. - groups (int): Same as nn.Conv2d. - bias (bool or str): If specified as `auto`, it will be decided by the - norm_cfg. Bias will be set as True if norm_cfg is None, otherwise - False. - """ - - _version = 2 - - def __init__(self, *args, **kwargs): - super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) - - self.conv_offset = nn.Conv2d( - self.in_channels+4, - self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, - stride=_pair(self.stride), - padding=_pair(self.padding), - dilation=_pair(self.dilation), - bias=True) - self.init_weights() - - def init_weights(self): - super(ModulatedDeformConvPack, self).init_weights() - if hasattr(self, 'conv_offset'): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x, gray_content): - out = self.conv_offset(torch.cat((x,gray_content),dim=1)) - o1, o2, mask = torch.chunk(out, 3, dim=1) - offset = torch.cat((o1, o2), dim=1) - mask = torch.sigmoid(mask) - - # return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, - # self.groups, self.deformable_groups) - return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, - self.dilation, mask) - - -# class SecondOrderDeformableAlignment(ModulatedDeformConvPack): -# """Second-order deformable alignment module. - -# Args: -# in_channels (int): Same as nn.Conv2d. -# out_channels (int): Same as nn.Conv2d. -# kernel_size (int or tuple[int]): Same as nn.Conv2d. -# stride (int or tuple[int]): Same as nn.Conv2d. -# padding (int or tuple[int]): Same as nn.Conv2d. -# dilation (int or tuple[int]): Same as nn.Conv2d. -# groups (int): Same as nn.Conv2d. -# bias (bool or str): If specified as `auto`, it will be decided by the -# norm_cfg. Bias will be set as True if norm_cfg is None, otherwise -# False. -# max_residue_magnitude (int): The maximum magnitude of the offset -# residue (Eq. 6 in paper). Default: 10. 
-# """ - -# def __init__(self, *args, **kwargs): -# self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) - -# super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) - -# self.conv_offset = nn.Sequential( -# nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), -# nn.LeakyReLU(negative_slope=0.1, inplace=True), -# nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), -# nn.LeakyReLU(negative_slope=0.1, inplace=True), -# nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), -# nn.LeakyReLU(negative_slope=0.1, inplace=True), -# nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1), -# ) - -# self.init_offset() - -# def init_offset(self): - -# def _constant_init(module, val, bias=0): -# if hasattr(module, 'weight') and module.weight is not None: -# nn.init.constant_(module.weight, val) -# if hasattr(module, 'bias') and module.bias is not None: -# nn.init.constant_(module.bias, bias) - -# _constant_init(self.conv_offset[-1], val=0, bias=0) - -# def forward(self, x, extra_feat, flow_1, flow_2): -# extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) -# out = self.conv_offset(extra_feat) -# o1, o2, mask = torch.chunk(out, 3, dim=1) - -# # offset -# offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) -# offset_1, offset_2 = torch.chunk(offset, 2, dim=1) -# offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1) -# offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1) -# offset = torch.cat([offset_1, offset_2], dim=1) - -# # mask -# mask = torch.sigmoid(mask) - -# return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, -# self.dilation, mask) - -class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) - - self.dcn_in = ModulatedDeformConvPack(block_in, - block_in, - kernel_size=3, - stride=1, - padding=1) - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock_dcn(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock_dcn(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - 
block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - # else: - # block.append(ResnetBlock_dcn(in_channels=block_in, - # out_channels=block_out, - # temb_channels=self.temb_ch, - # dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - # self.dcn_out = ModulatedDeformConvPack(out_ch, - # out_ch, - # kernel_size=3, - # stride=1, - # padding=1) - - def forward(self, z, gray_content_z): - #assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - # print("h",h.shape) - # print("gray_content_z",gray_content_z.shape) - h = self.dcn_in(h, gray_content_z)+h - - # middle - h = self.mid.block_1(h, gray_content_z,temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, gray_content_z,temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb)#h, gray_content_z,temb - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - # print(h.shape) - # h = self.dcn_out(h,gray_content_z) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - 
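The deformable Decoder above conditions its offset prediction on the grayscale content latent: dcn_in receives both the decoder feature map and gray_content_z, and ModulatedDeformConvPack concatenates them before conv_offset (hence the in_channels + 4 input width). A minimal sketch of that offset branch, with assumed channel counts:

    import torch
    import torch.nn as nn

    feat_ch, z_ch, K, groups = 512, 4, 3, 1       # assumed block_in / latent channels
    conv_offset = nn.Conv2d(feat_ch + z_ch, groups * 3 * K * K, K, padding=1)

    h = torch.randn(1, feat_ch, 32, 32)            # decoder feature map
    gray_z = torch.randn(1, z_ch, 32, 32)          # grayscale content latent
    o1, o2, m = torch.chunk(conv_offset(torch.cat((h, gray_z), dim=1)), 3, dim=1)
    offset, mask = torch.cat((o1, o2), dim=1), torch.sigmoid(m)
    print(offset.shape, mask.shape)                # [1, 18, 32, 32] [1, 9, 32, 32]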
self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = 
int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) - - def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: - return x - else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) - return x diff --git a/Control-Color/ldm/modules/diffusionmodules/model_brefore_dcn.py b/Control-Color/ldm/modules/diffusionmodules/model_brefore_dcn.py deleted file mode 100644 index b089eebbe1676d8249005bb9def002ff5180715b..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/diffusionmodules/model_brefore_dcn.py +++ /dev/null @@ -1,852 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange -from typing import Optional, Any - -from ldm.modules.attention import MemoryEfficientCrossAttention - -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - print("No module 'xformers'. Proceeding without it.") - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) - return emb - - -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = 
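A runnable restatement of the sinusoidal timestep embedding above, written inline so the shapes can be checked in isolation (same math as get_timestep_embedding; the example timesteps and dimension are arbitrary):

    import math
    import torch

    def timestep_embedding(timesteps, dim):
        half = dim // 2
        freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -(math.log(10000) / (half - 1)))
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if dim % 2 == 1:                           # zero-pad odd dimensions, as above
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb

    print(timestep_embedding(torch.tensor([0, 250, 999]), 128).shape)  # torch.Size([3, 128])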
torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) - - h_ = self.proj_out(h_) - - return x+h_ - -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) - out = self.proj_out(out) - return x+out - - -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): - b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) - return x + out - - -def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": - attn_type = "vanilla-xformers" - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise NotImplementedError() - - -class Model(nn.Module): - 
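AttnBlock above is single-head self-attention over the H*W spatial positions, implemented with two batched matmuls and a C**-0.5 scale; the sketch below reproduces that math on dummy tensors. (Worth noting: the memory-efficient-cross-attn branch in make_attn compares the builtin `type` rather than `attn_type`, so as written that branch can never be selected.)

    import torch

    b, c, h, w = 1, 16, 8, 8
    q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

    q_ = q.reshape(b, c, h * w).permute(0, 2, 1)                  # b, hw, c
    k_ = k.reshape(b, c, h * w)                                   # b, c, hw
    attn = torch.softmax(torch.bmm(q_, k_) * c ** -0.5, dim=2)    # b, hw, hw
    out = torch.bmm(v.reshape(b, c, h * w), attn.permute(0, 2, 1)).reshape(b, c, h, w)
    print(out.shape)  # torch.Size([1, 16, 8, 8])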
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h 
= self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, 
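Bookkeeping check for the skip connections in Model.forward above: every down-path output (conv_in, each ResnetBlock, each Downsample) is pushed onto hs, and the up path pops exactly one per block, so the two counts must match. The level and block counts below are assumed for illustration.

    num_resolutions, num_res_blocks = 4, 2
    pushed = 1 + num_resolutions * num_res_blocks + (num_resolutions - 1)
    popped = num_resolutions * (num_res_blocks + 1)
    assert pushed == popped
    print(pushed, popped)  # 12 12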
tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def 
forward(self, x): - for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class 
MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) - - def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: - return x - else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) - return x diff --git a/Control-Color/ldm/modules/diffusionmodules/openaimodel.py b/Control-Color/ldm/modules/diffusionmodules/openaimodel.py deleted file mode 100644 index 390bfc38029f513e85afbb917dee3225d7788f02..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,853 +0,0 @@ -from abc import abstractmethod -import math - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F - -from ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from ldm.modules.attention import SpatialTransformer# -from ldm.modules.attention_dcn_control import SpatialTransformer_dcn -from ldm.util import exists - - -# dummy replace -def convert_module_to_f16(x): - pass - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim 
** 0.5) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None):#,timestep=None,dcn_guide=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context=context)#,timestep=timestep) - elif isinstance(layer, SpatialTransformer_dcn): - # x = layer(x, context,dcn_guide) - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - -class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) - - def forward(self,x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. 
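TimestepEmbedSequential above is a plain nn.Sequential whose forward dispatches extra inputs by layer type: the timestep embedding goes to TimestepBlocks, the cross-attention context to SpatialTransformers, and nothing extra to everything else. A toy illustration of that dispatch rule, using stand-in classes rather than the repo's:

    import torch
    import torch.nn as nn

    class NeedsEmb(nn.Module):                     # stands in for TimestepBlock
        def forward(self, x, emb): return x + emb.mean()

    class NeedsContext(nn.Module):                 # stands in for SpatialTransformer
        def forward(self, x, context): return x + context.mean()

    def run(layers, x, emb, context):
        for layer in layers:
            if isinstance(layer, NeedsEmb):
                x = layer(x, emb)
            elif isinstance(layer, NeedsContext):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

    x = torch.randn(1, 4, 8, 8)
    out = run([NeedsEmb(), nn.SiLU(), NeedsContext()], x, torch.randn(1, 16), torch.randn(1, 77, 768))
    print(out.shape)  # torch.Size([1, 4, 8, 8])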
- """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. 
- """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - # self.cnnhead = CnnHead(512,num_classes=32,window_size=channels) - def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - # h = self.cnnhead(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
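Sketch of the FiLM-style conditioning used by ResBlock above when use_scale_shift_norm is enabled: the projected timestep embedding is split into (scale, shift) and applied as norm(h) * (1 + scale) + shift before the remaining output layers. Channel counts here are arbitrary.

    import torch
    import torch.nn as nn

    B, C, H, W = 2, 32, 8, 8
    h = torch.randn(B, C, H, W)
    emb_out = torch.randn(B, 2 * C)[..., None, None]   # broadcast over H, W
    scale, shift = torch.chunk(emb_out, 2, dim=1)

    h = nn.GroupNorm(32, C)(h) * (1 + scale) + shift
    print(h.shape)  # torch.Size([2, 32, 8, 8])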
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - -# class ModulatedDeformConv(nn.Module): -# """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. - -# Args: -# in_channels (int): Same as nn.Conv2d. -# out_channels (int): Same as nn.Conv2d. -# kernel_size (int or tuple[int]): Same as nn.Conv2d. -# stride (int or tuple[int]): Same as nn.Conv2d. -# padding (int or tuple[int]): Same as nn.Conv2d. -# dilation (int or tuple[int]): Same as nn.Conv2d. -# groups (int): Same as nn.Conv2d. -# bias (bool or str): If specified as `auto`, it will be decided by the -# norm_cfg. Bias will be set as True if norm_cfg is None, otherwise -# False. 
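The two attention modules above differ only in where the head split happens; the sketch below follows QKVAttentionLegacy, which reshapes the fused qkv tensor to (batch * heads, 3 * ch, length) before splitting, and scales q and k symmetrically by 1/sqrt(sqrt(ch)) so their product carries the usual 1/sqrt(ch) factor in an fp16-friendly way. Sizes are arbitrary.

    import math
    import torch

    bs, n_heads, ch, length = 2, 4, 16, 64
    qkv = torch.randn(bs, 3 * n_heads * ch, length)     # fused [N, H*3*C, T]

    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
    out = torch.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)
    print(out.shape)  # torch.Size([2, 64, 64])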
-# """ - -# _version = 2 - -# def __init__(self, *args, **kwargs): -# super(ModulatedDeformConv, self).__init__(*args, **kwargs) - -# self.conv_offset = nn.Conv2d( -# self.in_channels, -# self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], -# kernel_size=self.kernel_size, -# stride=_pair(self.stride), -# padding=_pair(self.padding), -# dilation=_pair(self.dilation), -# bias=True) -# self.init_weights() - -# def init_weights(self): -# super(ModulatedDeformConv, self).init_weights() -# if hasattr(self, 'conv_offset'): -# self.conv_offset.weight.data.zero_() -# self.conv_offset.bias.data.zero_() - -# def forward(self, x): -# out = self.conv_offset(x) -# o1, o2, mask = th.chunk(out, 3, dim=1) -# offset = th.cat((o1, o2), dim=1) -# mask = th.sigmoid(mask) -# return nn.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation,mask, -# self.groups, self.deformable_groups) - -from einops import rearrange -class CnnHead(nn.Module): - def __init__(self, embed_dim, num_classes, window_size): - super().__init__() - self.embed_dim = embed_dim - self.num_classes = num_classes - self.window_size = window_size - - self.cnnhead = nn.Conv2d(embed_dim, num_classes, kernel_size=3, stride=1, padding=1, padding_mode='reflect') - - def forward(self, x): - x = rearrange(x, 'b (p1 p2) c -> b c p1 p2', p1=self.window_size, p2=self.window_size) - x = self.cnnhead(x) - x = rearrange(x, 'b c p1 p2 -> b (p1 p2) c') - return x - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. 
- """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None, - disable_middle_self_attn=False, - use_linear_in_transformer=False, - ): - super().__init__() - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' - - if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' - from omegaconf.listconfig import ListConfig - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' - - if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") - self.num_res_blocks = num_res_blocks - if disable_self_attentions is not None: - # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not - assert len(disable_self_attentions) == len(channel_mult) - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - else: - raise ValueError() - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - 
dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[level] + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if not exists(num_attention_blocks) or i < num_attention_blocks[level]: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ) - ) - # layers.append(CnnHead(ch, ch, window_size=ch // 8)) - if level and i == self.num_res_blocks[level]: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - # layers.append(CnnHead(ch, ch, window_size=ch // 8)) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): - """ - Apply the model to an input batch. 
- :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. - """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) diff --git a/Control-Color/ldm/modules/diffusionmodules/util.py b/Control-Color/ldm/modules/diffusionmodules/util.py deleted file mode 100644 index 637363dfe34799e70cfdbcd11445212df9d9ca1f..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/diffusionmodules/util.py +++ /dev/null @@ -1,270 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! - - -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat - -from ldm.util import instantiate_from_config - - -def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if schedule == "linear": - betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) - else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') - 
return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) - if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. 
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. 
- """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Control-Color/ldm/modules/distributions/__init__.py b/Control-Color/ldm/modules/distributions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/Control-Color/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 591805d0e22bfbabbbc2992d97319d4d5fa0d191..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/Control-Color/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc deleted file mode 100644 index a78fc65a53c0af41200a562f57eb7e2afbd72182..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/distributions/distributions.py b/Control-Color/ldm/modules/distributions/distributions.py deleted file mode 100644 index 3656ce34af754139fbede03d950119e57f089b3a..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/distributions/distributions.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -import numpy as np - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def sample_addhint(self, generator): - latents = torch.randn(self.mean.shape, generator=generator, device='cpu', dtype=self.parameters.dtype).cuda() - x = self.mean + self.std * latents - return x - - def kl(self, other=None): 
- if self.deterministic: - return torch.Tensor([0.]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): - if self.deterministic: - return torch.Tensor([0.]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/Control-Color/ldm/modules/ema.py b/Control-Color/ldm/modules/ema.py deleted file mode 100644 index bded25019b9bcbcd0260f0b8185f8c7859ca58c4..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/ema.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') - - self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates - else torch.tensor(-1, dtype=torch.int)) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace('.', '') - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def reset_num_updates(self): - del self.num_updates - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - 
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/Control-Color/ldm/modules/encoders/__init__.py b/Control-Color/ldm/modules/encoders/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index fb770670aa01764aa44d88d8a21ca5f069b792cb..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index c9742f4d22c32b2aa7383ace945c0665c061f17a..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc b/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc deleted file mode 100644 index 0ecaf58206391f5ef455736777092393049786b3..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index b57b5063568c3737fdcfa5ff54c881cd8322052c..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/encoders/modules.py b/Control-Color/ldm/modules/encoders/modules.py deleted file mode 100644 index 5abde3faca5c4f0e0d9792c4945cc3d1e715b12b..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/encoders/modules.py +++ /dev/null @@ -1,605 +0,0 @@ -import torch -import torch.nn as nn -from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModel, CLIPImageProcessor - -import open_clip -from ldm.util import default, count_params -import kornia -# import clip -from einops import rearrange - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class IdentityEncoder(AbstractEncoder): - - def encode(self, x): - return x - - 
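[Editor's note, illustrative sketch only, not part of the deleted sources: the store/copy_to/restore cycle described in the docstrings of the LitEma class removed above (Control-Color/ldm/modules/ema.py) can be exercised as below. The toy model, optimizer, and loop are placeholders for whatever training code actually drives LitEma.]

import torch
import torch.nn as nn

# Assumes LitEma from the deleted ldm/modules/ema.py is importable.
model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                       # update the shadow (EMA) copies of all trainable params

ema.store(model.parameters())        # stash the live weights
ema.copy_to(model)                   # swap in the EMA weights for evaluation
with torch.no_grad():
    _ = model(torch.randn(2, 4))     # validate / save with EMA weights
ema.restore(model.parameters())      # put the live weights back before training resumes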
-class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): - super().__init__() - self.key = key - self.embedding = nn.Embedding(n_classes, embed_dim) - self.n_classes = n_classes - self.ucg_rate = ucg_rate - - def forward(self, batch, key=None, disable_dropout=False): - if key is None: - key = self.key - # this is for use in crossattn - c = batch[key][:, None] - if self.ucg_rate > 0. and not disable_dropout: - mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) - c = c.long() - c = self.embedding(c) - return c - - def get_unconditional_conditioning(self, bs, device="cuda"): - uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) - uc = torch.ones((bs,), device=device) * uc_class - uc = {self.key: uc} - return uc - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class FrozenT5Embedder(AbstractEncoder): - """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = T5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length # TODO: typical value? - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - #self.train = disabled_train - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens) - - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenCLIPEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 - super().__init__() - assert layer in self.LAYERS - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - self.layer_idx = layer_idx - if layer == "hidden": - assert layer_idx is not None - assert 0 <= abs(layer_idx) <= 12 - - def freeze(self): - self.transformer = self.transformer.eval() - #self.train = disabled_train - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] - else: - z = outputs.hidden_states[self.layer_idx] - # 
print(z.shape) - return z - - def encode(self, text): - return self(text) - -# class FrozenCLIPDualEmbedder(AbstractEncoder): -# """Uses the CLIP transformer encoder for text (from huggingface)""" -# LAYERS = [ -# "last", -# "pooled", -# "hidden" -# ] -# def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, -# freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 -# super().__init__() -# assert layer in self.LAYERS -# self.tokenizer = CLIPTokenizer.from_pretrained(version) -# self.transformer = CLIPTextModel.from_pretrained(version) -# # self.processor = CLIPImageProcessor.from_pretrained(version) -# # self.imagetransformer = CLIPVisionModel.from_pretrained(version) -# self.ImageEmbedder=FrozenClipImageEmbedder() -# self.device = device -# self.max_length = max_length -# if freeze: -# self.freeze() -# self.layer = layer -# self.layer_idx = layer_idx -# if layer == "hidden": -# assert layer_idx is not None -# assert 0 <= abs(layer_idx) <= 12 - -# def freeze(self): -# self.transformer = self.transformer.eval() -# #self.train = disabled_train -# for name,param in self.named_parameters(): -# if not "imagetransformer" in name and not "imageconv" in name and not "ImageEmbedder" in name: -# # print(name,param) -# param.requires_grad = False -# else: -# param.requires_grad = True -# # print(name) - -# def forward(self, text): -# # print("text:",len(text)) -# # if len(text)==1: -# # txt=text[0] -# # hint_image=None -# # elif len(text)==2: -# # txt,hint_image=text -# txt,hint_image=text -# # print(hint_image.shape) -# batch_encoding = self.tokenizer(txt, truncation=True, max_length=self.max_length, return_length=True, -# return_overflowing_tokens=False, padding="max_length", return_tensors="pt") -# tokens = batch_encoding["input_ids"].to(self.device) -# outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") -# # input_image_batch_encoding = self.processor(input_image,return_tensors="pt") -# # ii_tokens = input_image_batch_encoding["input_ids"].to(self.device) -# # ii_outputs = self.imagetransformer(input_ids=ii_tokens, output_hidden_states=self.layer=="hidden") - -# # hint_image_batch_encoding = self.processor(hint_image,return_tensors="pt") -# # hi_tokens = hint_image_batch_encoding["input_ids"].to(self.device) -# # hi_outputs = self.imagetransformer(input_ids=hi_tokens, output_hidden_states=self.layer=="hidden") - -# # hint_outputs = hi_outputs-ii_outputs -# # if hint_image==None: -# # if self.layer == "last": -# # z = outputs.last_hidden_state -# # elif self.layer == "pooled": -# # z = outputs.pooler_output[:, None, :] -# # else: -# # z = outputs.hidden_states[self.layer_idx] -# # # print("z",z.shape) -# # return z -# hint_outputs=self.ImageEmbedder(hint_image) -# # print("hint_outputs",hint_outputs.shape) -# # print("prompt",outputs.last_hidden_state.shape) -# if self.layer == "last": -# z = torch.cat((outputs.last_hidden_state,hint_outputs.unsqueeze(0)),1)#torch.cat((outputs.last_hidden_state,hint_outputs.last_hidden_state),1)#torch.cat((outputs.last_hidden_state,hint_outputs.unsqueeze(0)),1) -# elif self.layer == "pooled": -# z = torch.cat((outputs.pooler_output[:, None, :],hint_outputs.unsqueeze(0)),1) -# else: -# z = torch.cat((outputs.hidden_states[self.layer_idx],hint_outputs.unsqueeze(0)),1) -# # print("z",z.shape) -# return z - -# def encode(self, text): -# # print(text.shape) -# return self(text) - - -class FrozenCLIPDualEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from 
huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 - super().__init__() - assert layer in self.LAYERS - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - # self.processor = CLIPImageProcessor.from_pretrained(version) - # self.imagetransformer = CLIPVisionModel.from_pretrained(version) - self.ImageEmbedder=FrozenClipImageEmbedder() - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - self.layer_idx = layer_idx - if layer == "hidden": - assert layer_idx is not None - assert 0 <= abs(layer_idx) <= 12 - print("pooled") - - def freeze(self): - # self.transformer = self.transformer.eval() - #self.train = disabled_train - for name,param in self.named_parameters(): - # print(name) - # if not "imagetransformer" in name and not "imageconv" in name and not "ImageEmbedder" in name: - param.requires_grad = False - # if not "ImageEmbedder" in name: - # # print(name,param) - # param.requires_grad = False - # else: - # param.requires_grad = True - - - def forward(self, text): - # pdb.set_trace() - # print("text:",len(text)) - # if len(text)==1: - # txt=text[0] - # hint_image=None - # elif len(text)==2: - # txt,hint_image=text - txt,hint_image=text - # if hint_image==None: - # batch_encoding = self.tokenizer(txt, truncation=True, max_length=self.max_length, return_length=True, - # return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - # tokens = batch_encoding["input_ids"].to(self.device) - - # outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") - # prompt_outputs=outputs.last_hidden_state - # return prompt_outputs - # else: - # hint_image.requires_grad_(True) - # print(hint_image.shape) - batch_encoding = self.tokenizer(txt, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") - prompt_outputs=outputs.last_hidden_state - # prompt_outputs=outputs.last_hidden_state.detach().requires_grad_(True) - # prompt_outputs.retain_grad() - # input_image_batch_encoding = self.processor(input_image,return_tensors="pt") - # ii_tokens = input_image_batch_encoding["input_ids"].to(self.device) - # ii_outputs = self.imagetransformer(input_ids=ii_tokens, output_hidden_states=self.layer=="hidden") - - # hint_image_batch_encoding = self.processor(hint_image,return_tensors="pt") - # hi_tokens = hint_image_batch_encoding["input_ids"].to(self.device) - # hi_outputs = self.imagetransformer(input_ids=hi_tokens, output_hidden_states=self.layer=="hidden") - - # hint_outputs = hi_outputs-ii_outputs - # if hint_image==None: - # if self.layer == "last": - # z = outputs.last_hidden_state - # elif self.layer == "pooled": - # z = outputs.pooler_output[:, None, :] - # else: - # z = outputs.hidden_states[self.layer_idx] - # # print("z",z.shape) - # return z - # pdb.set_trace() - outputs = self.ImageEmbedder(hint_image) - # image_embeds = outputs.pooler_output #outputs.image_embeds - image_embeds = outputs.pooler_output - # print(image_embeds.shape) - # last_hidden_state = outputs.last_hidden_state - # pooled_output = outputs.pooler_output - # 
print("hint_outputs",last_hidden_state.shape) - # print("pooled_output", pooled_output.shape) - # print("prompt",prompt_outputs.shape) - - if self.layer == "last": - # print(prompt_outputs.shape) - # print(image_embeds.shape) - z = torch.cat((prompt_outputs,image_embeds.unsqueeze(1)),1)#,hint_outputs.unsqueeze(0)),1) - # z = torch.cat((prompt_outputs,hint_outputs.last_hidden_state),1)#,hint_outputs.unsqueeze(0)),1) - elif self.layer == "pooled": - z = torch.cat((outputs.pooler_output[:, None, :],hint_outputs.unsqueeze(0)),1) - else: - z = torch.cat((outputs.hidden_states[self.layer_idx],hint_outputs.unsqueeze(0)),1) - - return z - # def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - # freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 - # super().__init__() - # assert layer in self.LAYERS - # # self.processor = CLIPImageProcessor.from_pretrained(version) - # # self.imagetransformer = CLIPVisionModel.from_pretrained(version) - # self.ImageEmbedder=FrozenClipImageEmbedder() - # self.device = device - # self.max_length = max_length - # if freeze: - # self.freeze() - # self.layer = layer - # self.layer_idx = layer_idx - # if layer == "hidden": - # assert layer_idx is not None - # assert 0 <= abs(layer_idx) <= 12 - - # def freeze(self): - # #self.train = disabled_train - # for name,param in self.named_parameters(): - # if not "imagetransformer" in name and not "imageconv" in name and not "ImageEmbedder" in name: - # # print(name,param) - # param.requires_grad = False - # else: - # param.requires_grad = True - # # print(name) - - # def forward(self, txt,hint_image): - # # pdb.set_trace() - # hint_outputs=self.ImageEmbedder(hint_image) - # # print("hint_outputs",hint_outputs.shape) - # # print("prompt",outputs.last_hidden_state.shape) - # if self.layer == "last": - # print(txt.shape) - # print(hint_outputs.last_hidden_state.shape) - # z = torch.cat((txt,hint_outputs.last_hidden_state),1)#,hint_outputs.unsqueeze(0)),1) - # elif self.layer == "pooled": - # z = torch.cat((txt,hint_outputs.unsqueeze(0)),1) - # else: - # z = torch.cat((txt,hint_outputs.unsqueeze(0)),1) - # # print("z",z.shape) - # return z - - def encode(self, text): - - # if isinstance(text, dict): - # txt,hint_image=text['c_crossattn'][0] - # txt=txt - # else: - # txt,hint_image=text - # txt = text - txt, x = text - # if x==None: - # return self((txt,x)) - # print(x.shape) - if len(x.shape) == 3: - x = x[..., None] - - x = rearrange(x, 'b h w c -> b c h w') - x = x.to(memory_format=torch.contiguous_format).float() - x = [x[i] for i in range(x.shape[0])] - return self((txt, x)) - -class FrozenOpenCLIPEmbedder(AbstractEncoder): - """ - Uses the OpenCLIP transformer encoder for text - """ - LAYERS = [ - #"pooled", - "last", - "penultimate" - ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) - del model.visual - self.model = model - - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = 
open_clip.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - return z - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.model.ln_final(x) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - self.layer_idx: - break - if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - def encode(self, text): - return self(text) - - -class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): - super().__init__() - self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) - self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") - - def encode(self, text): - return self(text) - - def forward(self, text): - clip_z = self.clip_encoder.encode(text) - t5_z = self.t5_encoder.encode(text) - return [clip_z, t5_z] - -class FrozenClipImageEmbedder(nn.Module): - """ - Uses the CLIP image encoder. - """ - def __init__( - self, - model='ViT-B/16', #ViT-L/14 - jit=False, - device='cuda' if torch.cuda.is_available() else 'cpu', - antialias=False, - ): - super().__init__() - # self.model, _ = clip.load(name=model, device=device, jit=jit) - # self.model.requires_grad_(True) - self.imageconv = nn.Conv2d(4,3,(3,3),padding=1)#.cuda() - self.antialias = antialias - - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - self.device = device - self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32") - self.model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") - # self.imagetransformer = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16") - - - # def preprocess(self, x): - # # normalize to [0,1] - # # print(x.shape) - # # pdb.set_trace() - # x = kornia.geometry.resize(x, (224, 224), - # interpolation='bicubic',align_corners=True, - # antialias=self.antialias) - # # print("after",x.shape) - # # x = (x + 1.) / 2. - # print(x) - # # renormalize according to clip - # x = kornia.enhance.normalize(x, self.mean, self.std) - # # print("after1111111",x.shape) - # return x - - def forward(self, x): - # x is assumed to be in range [-1,1] - # pdb.set_trace() - # with torch.set_grad_enabled(True): - # print("before",x.shape) - # x=self.imageconv(x) - # print("after",x.shape) - # x = x.tolist() - - x = self.processor(x, return_tensors="pt") - # print(x) - # pdb.set_trace() - x['pixel_values'] = x['pixel_values'].to(self.device) - outputs = self.model(**x) - return outputs - -# class FrozenClipImageEmbedder(nn.Module): -# """ -# Uses the CLIP image encoder. 
-# """ -# def __init__( -# self, -# model='ViT-B/16', -# jit=False, -# device='cuda' if torch.cuda.is_available() else 'cpu', -# antialias=False, -# ): -# super().__init__() -# self.model, _ = clip.load(name=model, device=device, jit=jit) -# # self.imageconv = nn.Conv2d(4,3,(3,3),stride=2) -# self.antialias = antialias - -# self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) -# self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - -# def preprocess(self, x): -# # normalize to [0,1] -# # print(x.shape) -# x = kornia.geometry.resize(x, (224, 224), -# interpolation='bicubic',align_corners=True, -# antialias=self.antialias) -# # print("after",x.shape) -# x = (x + 1.) / 2. -# # renormalize according to clip -# x = kornia.enhance.normalize(x, self.mean, self.std) -# # print("after1111111",x.shape) -# return x - -# def forward(self, x): -# # x is assumed to be in range [-1,1] -# # x=self.imageconv(x) -# return self.model.encode_image(self.preprocess(x)) - -# class FrozenClipImageEmbedder(nn.Module): -# """ -# Uses the CLIP image encoder. -# """ -# def __init__( -# self, -# model='ViT-B/16', #ViT-L/14 -# jit=False, -# device='cuda' if torch.cuda.is_available() else 'cpu', -# antialias=False, -# ): -# super().__init__() -# self.model, _ = clip.load(name=model, device=device, jit=jit) -# # self.model.requires_grad_(True) -# # self.imageconv = nn.Conv2d(4,3,(3,3),padding=1)#.cuda()#padding=1 #stride=2 -# self.antialias = antialias - -# self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) -# self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - -# # self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") -# self.imagetransformer = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16") - -# def preprocess(self, x): -# # normalize to [0,1] -# # print(x.shape) -# # pdb.set_trace() -# x = kornia.geometry.resize(x, (224, 224), -# interpolation='bicubic',align_corners=True, -# antialias=self.antialias) -# # print("after",x.shape) -# x = (x + 1.) / 2. 
-# # renormalize according to clip -# x = kornia.enhance.normalize(x, self.mean, self.std) -# # print("after1111111",x.shape) -# return x - -# def forward(self, x): -# # x is assumed to be in range [-1,1] -# # x=self.imageconv(x) -# return self.imagetransformer(self.preprocess(x), output_hidden_states="last"=="hidden") #self.model.encode_image(self.preprocess(x)) diff --git a/Control-Color/ldm/modules/image_degradation/__init__.py b/Control-Color/ldm/modules/image_degradation/__init__.py deleted file mode 100644 index 7836cada81f90ded99c58d5942eea4c3477f58fc..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/image_degradation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr -from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/Control-Color/ldm/modules/image_degradation/bsrgan.py b/Control-Color/ldm/modules/image_degradation/bsrgan.py deleted file mode 100644 index 32ef56169978e550090261cddbcf5eb611a6173b..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/image_degradation/bsrgan.py +++ /dev/null @@ -1,730 +0,0 @@ -# -*- coding: utf-8 -*- -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# From 2019/03--2021/08 -# -------------------------------------------- -""" - -import numpy as np -import cv2 -import torch - -from functools import partial -import random -from scipy import ndimage -import scipy -import scipy.stats as ss -from scipy.interpolate import interp2d -from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util - - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def analytic_kernel(k): - """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" - k_size = k.shape[0] - # Calculate the big kernels size - big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) - # Loop over the small kernel to fill the big one - for r in range(k_size): - for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k - # Crop the edges of the big kernel to ignore very small values and increase run time of SR - crop = k_size // 2 - cropped_big_k = big_k[crop:-crop, crop:-crop] - # Normalize to 1 - return cropped_big_k / cropped_big_k.sum() - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. 
- Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf - 1) * 0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w - 1) - y1 = np.clip(y1, 0, h - 1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -def blur(x, k): - ''' - x: image, NxcxHxW - k: kernel, Nx1xhxw - ''' - n, c = x.shape[:2] - p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') - k = k.repeat(1, c, 1, 1) - k = k.view(-1, 1, k.shape[2], k.shape[3]) - x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) - x = x.view(n, c, x.shape[2], x.shape[3]) - - return x - - -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z - MU - ZZ_t = ZZ.transpose(0, 1, 3, 2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] - std = 
sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) - arg = -(x * x + y * y) / (2 * std * std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h / sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha, 1])]) - h1 = alpha / (alpha + 1) - h2 = (1 - alpha) / (alpha + 1) - h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1 / sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def add_sharpening(img, weight=0.5, radius=50, threshold=10): - """USM sharpening. borrowed from real-ESRGAN - Input image: I; Blurry image: B. - 1. K = I + weight * (I - B) - 2. Mask = 1 if abs(I - B) > threshold, else: 0 - 3. Blur mask: - 4. Out = Mask * K + (1 - Mask) * I - Args: - img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. - weight (float): Sharp weight. Default: 1. - radius (float): Kernel size of Gaussian blur. Default: 50. 
- threshold (int): - """ - if radius % 2 == 0: - radius += 1 - blur = cv2.GaussianBlur(img, (radius, radius), 0) - residual = img - blur - mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') - soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) - - K = img + weight * residual - K = np.clip(K, 0, 1) - return soft_mask * K + (1 - soft_mask) * img - - -def add_blur(img, sf=4): - wd2 = 4.0 + sf - wd = 2.0 + 0.2 * sf - if random.random() < 0.5: - l1 = wd2 * random.random() - l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) - else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') - - return img - - -def add_resize(img, sf=4): - rnum = np.random.rand() - if rnum > 0.8: # up - sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down - sf1 = random.uniform(0.5 / sf, 1) - else: - sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - return img - - -# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): -# noise_level = random.randint(noise_level1, noise_level2) -# rnum = np.random.rand() -# if rnum > 0.6: # add color Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) -# elif rnum < 0.4: # add grayscale Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) -# else: # add noise -# L = noise_level2 / 255. -# D = np.diag(np.random.rand(3)) -# U = orth(np.random.rand(3, 3)) -# conv = np.dot(np.dot(np.transpose(U), D), U) -# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) -# img = np.clip(img, 0.0, 1.0) -# return img - -def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_speckle_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - img = np.clip(img, 0.0, 1.0) - rnum = random.random() - if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. 
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4] - if random.random() < 0.5: - img = np.random.poisson(img * vals).astype(np.float32) / vals - else: - img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - img += noise_gray[:, :, np.newaxis] - img = np.clip(img, 0.0, 1.0) - return img - - -def add_JPEG_noise(img): - quality_factor = random.randint(30, 95) - img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) - img = cv2.imdecode(encimg, 1) - img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) - return img - - -def random_crop(lq, hq, sf=4, lq_patchsize=64): - h, w = lq.shape[:2] - rnd_h = random.randint(0, h - lq_patchsize) - rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] - - rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] - return lq, hq - - -def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - hq = img.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - img = util.imresize_np(img, 1 / 2, True) - img = np.clip(img, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: - img = add_blur(img, sf=sf) - - elif i == 2: - a, b = img.shape[1], img.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] 
# nearest downsampling - img = np.clip(img, 0.0, 1.0) - - elif i == 3: - # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - img = add_JPEG_noise(img) - - elif i == 6: - # add processed camera sensor noise - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf_ori, lq_patchsize) - - return img, hq - - -# todo no isp_model? -def degradation_bsrgan_variant(image, sf=4, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - image = util.uint2single(image) - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = image.shape[:2] - - hq = image.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - image = util.imresize_np(image, 1 / 2, True) - image = np.clip(image, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - image = add_blur(image, sf=sf) - - elif i == 1: - image = add_blur(image, sf=sf) - - elif i == 2: - a, b = image.shape[1], image.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] 
# nearest downsampling - image = np.clip(image, 0.0, 1.0) - - elif i == 3: - # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - image = np.clip(image, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - image = add_JPEG_noise(image) - - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - image = add_JPEG_noise(image) - image = util.single2uint(image) - example = {"image":image} - return example - - -# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... -def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): - """ - This is an extended degradation model by combining - the degradation models of BSRGAN and Real-ESRGAN - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - use_shuffle: the degradation shuffle - use_sharp: sharpening the img - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - if use_sharp: - img = add_sharpening(img) - hq = img.copy() - - if random.random() < shuffle_prob: - shuffle_order = random.sample(range(13), 13) - else: - shuffle_order = list(range(13)) - # local shuffle for noise, JPEG is always the last one - shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) - shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) - - poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 - - for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - elif i == 1: - img = add_resize(img, sf=sf) - elif i == 2: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 3: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 4: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 5: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - elif i == 6: - img = add_JPEG_noise(img) - elif i == 7: - img = add_blur(img, sf=sf) - elif i == 8: - img = add_resize(img, sf=sf) - elif i == 9: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 10: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 11: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 12: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - else: - print('check the shuffle!') - - # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf, lq_patchsize) - 
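- # img is now the degraded low-quality patch of size lq_patchsize x lq_patchsize; hq is the aligned clean patch sf times larger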
- return img, hq - - -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - diff --git a/Control-Color/ldm/modules/image_degradation/bsrgan_light.py b/Control-Color/ldm/modules/image_degradation/bsrgan_light.py deleted file mode 100644 index 808c7f882cb75e2ba2340d5b55881d11927351f0..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/image_degradation/bsrgan_light.py +++ /dev/null @@ -1,651 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial -import random -from scipy import ndimage -import scipy -import scipy.stats as ss -from scipy.interpolate import interp2d -from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util - -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# From 2019/03--2021/08 -# -------------------------------------------- -""" - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def analytic_kernel(k): - """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" - k_size = k.shape[0] - # Calculate the big kernels size - big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) - # Loop over the small kernel to fill the big one - for r in range(k_size): - for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k - # Crop the edges of the big kernel to ignore very small values and increase run time of SR - crop = k_size // 2 - cropped_big_k = big_k[crop:-crop, crop:-crop] - # Normalize to 1 - return cropped_big_k / cropped_big_k.sum() - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. 
- Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf - 1) * 0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w - 1) - y1 = np.clip(y1, 0, h - 1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -def blur(x, k): - ''' - x: image, NxcxHxW - k: kernel, Nx1xhxw - ''' - n, c = x.shape[:2] - p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') - k = k.repeat(1, c, 1, 1) - k = k.view(-1, 1, k.shape[2], k.shape[3]) - x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) - x = x.view(n, c, x.shape[2], x.shape[3]) - - return x - - -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z - MU - ZZ_t = ZZ.transpose(0, 1, 3, 2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] - std = 
sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) - arg = -(x * x + y * y) / (2 * std * std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h / sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha, 1])]) - h1 = alpha / (alpha + 1) - h2 = (1 - alpha) / (alpha + 1) - h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1 / sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def add_sharpening(img, weight=0.5, radius=50, threshold=10): - """USM sharpening. borrowed from real-ESRGAN - Input image: I; Blurry image: B. - 1. K = I + weight * (I - B) - 2. Mask = 1 if abs(I - B) > threshold, else: 0 - 3. Blur mask: - 4. Out = Mask * K + (1 - Mask) * I - Args: - img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. - weight (float): Sharp weight. Default: 1. - radius (float): Kernel size of Gaussian blur. Default: 50. 
- threshold (int): - """ - if radius % 2 == 0: - radius += 1 - blur = cv2.GaussianBlur(img, (radius, radius), 0) - residual = img - blur - mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') - soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) - - K = img + weight * residual - K = np.clip(K, 0, 1) - return soft_mask * K + (1 - soft_mask) * img - - -def add_blur(img, sf=4): - wd2 = 4.0 + sf - wd = 2.0 + 0.2 * sf - - wd2 = wd2/4 - wd = wd/4 - - if random.random() < 0.5: - l1 = wd2 * random.random() - l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) - else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') - - return img - - -def add_resize(img, sf=4): - rnum = np.random.rand() - if rnum > 0.8: # up - sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down - sf1 = random.uniform(0.5 / sf, 1) - else: - sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - return img - - -# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): -# noise_level = random.randint(noise_level1, noise_level2) -# rnum = np.random.rand() -# if rnum > 0.6: # add color Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) -# elif rnum < 0.4: # add grayscale Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) -# else: # add noise -# L = noise_level2 / 255. -# D = np.diag(np.random.rand(3)) -# U = orth(np.random.rand(3, 3)) -# conv = np.dot(np.dot(np.transpose(U), D), U) -# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) -# img = np.clip(img, 0.0, 1.0) -# return img - -def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_speckle_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - img = np.clip(img, 0.0, 1.0) - rnum = random.random() - if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. 
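- # half the time the Poisson noise is drawn independently per channel; otherwise it is derived from the luma image and added equally to all three channels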
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4] - if random.random() < 0.5: - img = np.random.poisson(img * vals).astype(np.float32) / vals - else: - img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - img += noise_gray[:, :, np.newaxis] - img = np.clip(img, 0.0, 1.0) - return img - - -def add_JPEG_noise(img): - quality_factor = random.randint(80, 95) - img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) - img = cv2.imdecode(encimg, 1) - img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) - return img - - -def random_crop(lq, hq, sf=4, lq_patchsize=64): - h, w = lq.shape[:2] - rnd_h = random.randint(0, h - lq_patchsize) - rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] - - rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] - return lq, hq - - -def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - hq = img.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - img = util.imresize_np(img, 1 / 2, True) - img = np.clip(img, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: - img = add_blur(img, sf=sf) - - elif i == 2: - a, b = img.shape[1], img.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] 
# nearest downsampling - img = np.clip(img, 0.0, 1.0) - - elif i == 3: - # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - img = add_JPEG_noise(img) - - elif i == 6: - # add processed camera sensor noise - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf_ori, lq_patchsize) - - return img, hq - - -# todo no isp_model? -def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - image = util.uint2single(image) - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = image.shape[:2] - - hq = image.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - image = util.imresize_np(image, 1 / 2, True) - image = np.clip(image, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - image = add_blur(image, sf=sf) - - # elif i == 1: - # image = add_blur(image, sf=sf) - - if i == 0: - pass - - elif i == 2: - a, b = image.shape[1], image.shape[0] - # downsample2 - if random.random() < 0.8: - sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] 
# nearest downsampling - - image = np.clip(image, 0.0, 1.0) - - elif i == 3: - # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - image = np.clip(image, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - image = add_JPEG_noise(image) - # - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - image = add_JPEG_noise(image) - image = util.single2uint(image) - if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then - example = {"image": image} - return example - - - - -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_hq = img - img_lq = deg_fn(img)["image"] - img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') diff --git a/Control-Color/ldm/modules/image_degradation/utils/test.png b/Control-Color/ldm/modules/image_degradation/utils/test.png deleted file mode 100644 index e720ed04ac7e1e7938d367e692fb6a742c54a24c..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/image_degradation/utils/test.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:92e516278f0d3e85e84cfb55b43338e12d5896a0ee3833aafdf378025457d753 -size 441072 diff --git a/Control-Color/ldm/modules/image_degradation/utils_image.py b/Control-Color/ldm/modules/image_degradation/utils_image.py deleted file mode 100644 index 0175f155ad900ae33c3c46ed87f49b352e3faf98..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/image_degradation/utils_image.py +++ /dev/null @@ -1,916 +0,0 @@ -import os -import math -import random -import numpy as np -import torch -import cv2 -from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py - - -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -# https://github.com/twhui/SRGAN-pyTorch -# https://github.com/xinntao/BasicSR -# -------------------------------------------- -''' - - -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] - - -def is_image_file(filename): - return any(filename.endswith(extension) for 
extension in IMG_EXTENSIONS) - - -def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') - - -def imshow(x, title=None, cbar=False, figsize=None): - plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') - if title: - plt.title(title) - if cbar: - plt.colorbar() - plt.show() - - -def surf(Z, cmap='rainbow', figsize=None): - plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') - - w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) - X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) - plt.show() - - -''' -# -------------------------------------------- -# get image pathes -# -------------------------------------------- -''' - - -def get_image_paths(dataroot): - paths = None # return None if dataroot is None - if dataroot is not None: - paths = sorted(_get_paths_from_images(dataroot)) - return paths - - -def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) - images = [] - for dirpath, _, fnames in sorted(os.walk(path)): - for fname in sorted(fnames): - if is_image_file(fname): - img_path = os.path.join(dirpath, fname) - images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) - return images - - -''' -# -------------------------------------------- -# split large images into small images -# -------------------------------------------- -''' - - -def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): - w, h = img.shape[:2] - patches = [] - if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) - for i in w1: - for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) - else: - patches.append(img) - - return patches - - -def imssave(imgs, img_path): - """ - imgs: list, N images of size WxHxC - """ - img_name, ext = os.path.splitext(os.path.basename(img_path)) - - for i, img in enumerate(imgs): - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') - cv2.imwrite(new_path, img) - - -def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): - """ - split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), - and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) - will be splitted. - Args: - original_dataroot: - taget_dataroot: - p_size: size of small images - p_overlap: patch size in training is a good choice - p_max: images with smaller size than (p_max)x(p_max) keep unchanged. 
- """ - paths = get_image_paths(original_dataroot) - for img_path in paths: - # img_name, ext = os.path.splitext(os.path.basename(img_path)) - img = imread_uint(img_path, n_channels=n_channels) - patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path - -''' -# -------------------------------------------- -# makedir -# -------------------------------------------- -''' - - -def mkdir(path): - if not os.path.exists(path): - os.makedirs(path) - - -def mkdirs(paths): - if isinstance(paths, str): - mkdir(paths) - else: - for path in paths: - mkdir(path) - - -def mkdir_and_rename(path): - if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) - os.rename(path, new_name) - os.makedirs(path) - - -''' -# -------------------------------------------- -# read image from path -# opencv is fast, but read BGR numpy image -# -------------------------------------------- -''' - - -# -------------------------------------------- -# get uint8 image of size HxWxn_channles (RGB) -# -------------------------------------------- -def imread_uint(path, n_channels=3): - # input: path - # output: HxWx3(RGB or GGG), or HxWx1 (G) - if n_channels == 1: - img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE - img = np.expand_dims(img, axis=2) # HxWx1 - elif n_channels == 3: - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G - if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG - else: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB - return img - - -# -------------------------------------------- -# matlab's imwrite -# -------------------------------------------- -def imsave(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - -def imwrite(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - - - -# -------------------------------------------- -# get single image of size HxWxn_channles (BGR) -# -------------------------------------------- -def read_img(path): - # read image by cv2 - # return: Numpy float32, HWC, BGR, [0,1] - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - -''' -# -------------------------------------------- -# image format conversion -# -------------------------------------------- -# numpy(single) <---> numpy(unit) -# numpy(single) <---> tensor -# numpy(unit) <---> tensor -# -------------------------------------------- -''' - - -# -------------------------------------------- -# numpy(single) [0, 1] <---> numpy(unit) -# -------------------------------------------- - - -def uint2single(img): - - return np.float32(img/255.) - - -def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) - - -def uint162single(img): - - return np.float32(img/65535.) 
- - -def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) - - -# -------------------------------------------- -# numpy(unit) (HxWxC or HxW) <---> tensor -# -------------------------------------------- - - -# convert uint to 4-dimensional torch tensor -def uint2tensor4(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) - - -# convert uint to 3-dimensional torch tensor -def uint2tensor3(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) - - -# convert 2/3/4-dimensional torch tensor to uint -def tensor2uint(img): - img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) - - -# -------------------------------------------- -# numpy(single) (HxWxC) <---> tensor -# -------------------------------------------- - - -# convert single (HxWxC) to 3-dimensional torch tensor -def single2tensor3(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() - - -# convert single (HxWxC) to 4-dimensional torch tensor -def single2tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) - - -# convert torch tensor to single -def tensor2single(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - - return img - -# convert torch tensor to single -def tensor2single3(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - elif img.ndim == 2: - img = np.expand_dims(img, axis=2) - return img - - -def single2tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) - - -def single32tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) - - -def single42tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() - - -# from skimage.io import imread, imsave -def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' - Converts a torch Tensor into an image Numpy array of BGR channel order - Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order - Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' - tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp - tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] - n_dim = tensor.dim() - if n_dim == 4: - n_img = len(tensor) - img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 3: - img_np = tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 2: - img_np = tensor.numpy() - else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) - if out_type == np.uint8: - img_np = (img_np * 255.0).round() - # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. - return img_np.astype(out_type) - - -''' -# -------------------------------------------- -# Augmentation, flipe and/or rotate -# -------------------------------------------- -# The following two are enough. 
-# (1) augmet_img: numpy image of WxHxC or WxH -# (2) augment_img_tensor4: tensor image 1xCxWxH -# -------------------------------------------- -''' - - -def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return np.flipud(np.rot90(img)) - elif mode == 2: - return np.flipud(img) - elif mode == 3: - return np.rot90(img, k=3) - elif mode == 4: - return np.flipud(np.rot90(img, k=2)) - elif mode == 5: - return np.rot90(img) - elif mode == 6: - return np.rot90(img, k=2) - elif mode == 7: - return np.flipud(np.rot90(img, k=3)) - - -def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return img.rot90(1, [2, 3]).flip([2]) - elif mode == 2: - return img.flip([2]) - elif mode == 3: - return img.rot90(3, [2, 3]) - elif mode == 4: - return img.rot90(2, [2, 3]).flip([2]) - elif mode == 5: - return img.rot90(1, [2, 3]) - elif mode == 6: - return img.rot90(2, [2, 3]) - elif mode == 7: - return img.rot90(3, [2, 3]).flip([2]) - - -def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - img_size = img.size() - img_np = img.data.cpu().numpy() - if len(img_size) == 3: - img_np = np.transpose(img_np, (1, 2, 0)) - elif len(img_size) == 4: - img_np = np.transpose(img_np, (2, 3, 1, 0)) - img_np = augment_img(img_np, mode=mode) - img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) - if len(img_size) == 3: - img_tensor = img_tensor.permute(2, 0, 1) - elif len(img_size) == 4: - img_tensor = img_tensor.permute(3, 2, 0, 1) - - return img_tensor.type_as(img) - - -def augment_img_np3(img, mode=0): - if mode == 0: - return img - elif mode == 1: - return img.transpose(1, 0, 2) - elif mode == 2: - return img[::-1, :, :] - elif mode == 3: - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 4: - return img[:, ::-1, :] - elif mode == 5: - img = img[:, ::-1, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 6: - img = img[:, ::-1, :] - img = img[::-1, :, :] - return img - elif mode == 7: - img = img[:, ::-1, :] - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - - -def augment_imgs(img_list, hflip=True, rot=True): - # horizontal flip OR rotate - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - -''' -# -------------------------------------------- -# modcrop and shave -# -------------------------------------------- -''' - - -def modcrop(img_in, scale): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - if img.ndim == 2: - H, W = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] - elif img.ndim == 3: - H, W, C = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] - else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) - return img - - -def shave(img_in, border=0): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - h, w = img.shape[:2] - img = img[border:h-border, border:w-border] - return img - - -''' -# -------------------------------------------- -# image processing process on numpy image -# channel_convert(in_c, tar_type, img_list): -# rgb2ycbcr(img, only_y=True): -# bgr2ycbcr(img, only_y=True): -# 
ycbcr2rgb(img): -# -------------------------------------------- -''' - - -def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr - only_y: only return Y channel - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - if only_y: - rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 - else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr - only_y: only return Y channel - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - if only_y: - rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 - else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def channel_convert(in_c, tar_type, img_list): - # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray - gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] - return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y - y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] - return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR - return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] - else: - return img_list - - -''' -# -------------------------------------------- -# metric, PSNR and SSIM -# -------------------------------------------- -''' - - -# -------------------------------------------- -# PSNR -# -------------------------------------------- -def calculate_psnr(img1, img2, border=0): - # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() - if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') - h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) - if mse == 0: - return float('inf') - return 20 * math.log10(255.0 / math.sqrt(mse)) - - -# -------------------------------------------- -# SSIM -# -------------------------------------------- -def calculate_ssim(img1, img2, border=0): - '''calculate SSIM - the same outputs as MATLAB's - img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() - if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') - h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = 
img2[border:h-border, border:w-border] - - if img1.ndim == 2: - return ssim(img1, img2) - elif img1.ndim == 3: - if img1.shape[2] == 3: - ssims = [] - for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) - return np.array(ssims).mean() - elif img1.shape[2] == 1: - return ssim(np.squeeze(img1), np.squeeze(img2)) - else: - raise ValueError('Wrong input image dimensions.') - - -def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - kernel = cv2.getGaussianKernel(11, 1.5) - window = np.outer(kernel, kernel.transpose()) - - mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid - mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] - mu1_sq = mu1**2 - mu2_sq = mu2**2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq - sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq - sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return ssim_map.mean() - - -''' -# -------------------------------------------- -# matlab's bicubic imresize (numpy and torch) [0, 1] -# -------------------------------------------- -''' - - -# matlab 'imresize' function, now only support 'bicubic' -def cubic(x): - absx = torch.abs(x) - absx2 = absx**2 - absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) - - -def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): - if (scale < 1) and (antialiasing): - # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width - kernel_width = kernel_width / scale - - # Output-space coordinates - x = torch.linspace(1, out_length, out_length) - - # Input-space coordinates. Calculate the inverse mapping such that 0.5 - # in output space maps to 0.5 in input space, and 0.5+scale in output - # space maps to 1.5 in input space. - u = x / scale + 0.5 * (1 - 1 / scale) - - # What is the left-most pixel that can be involved in the computation? - left = torch.floor(u - kernel_width / 2) - - # What is the maximum number of pixels that can be involved in the - # computation? Note: it's OK to use an extra pixel here; if the - # corresponding weights are all zero, it will be eliminated at the end - # of this function. - P = math.ceil(kernel_width) + 2 - - # The indices of the input pixels involved in computing the k-th output - # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) - - # The weights used to compute the k-th output pixel are in row k of the - # weights matrix. - distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices - # apply cubic kernel - if (scale < 1) and (antialiasing): - weights = scale * cubic(distance_to_center * scale) - else: - weights = cubic(distance_to_center) - # Normalize the weights matrix so that each row sums to 1. - weights_sum = torch.sum(weights, 1).view(out_length, 1) - weights = weights / weights_sum.expand(out_length, P) - - # If a column in weights is all zero, get rid of it. only consider the first and last column. 
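- # (P was deliberately chosen one tap wider than needed above, so an all-zero boundary column is common)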
- weights_zero_tmp = torch.sum((weights == 0), 0) - if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): - indices = indices.narrow(1, 1, P - 2) - weights = weights.narrow(1, 1, P - 2) - if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): - indices = indices.narrow(1, 0, P - 2) - weights = weights.narrow(1, 0, P - 2) - weights = weights.contiguous() - indices = indices.contiguous() - sym_len_s = -indices.min() + 1 - sym_len_e = indices.max() - in_length - indices = indices + sym_len_s - 1 - return weights, indices, int(sym_len_s), int(sym_len_e) - - -# -------------------------------------------- -# imresize for tensor image [0, 1] -# -------------------------------------------- -def imresize(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: pytorch tensor, CHW or HW [0,1] - # output: CHW or HW [0,1] w/o round - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(0) - in_C, in_H, in_W = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) - img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:, :sym_len_Hs, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[:, -sym_len_He:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(in_C, out_H, in_W) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) - out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :, :sym_len_Ws] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, :, -sym_len_We:] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(in_C, out_H, out_W) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - return out_2 - - -# -------------------------------------------- -# imresize for numpy image [0, 1] -# -------------------------------------------- -def 
imresize_np(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: Numpy, HWC or HW [0,1] - # output: HWC or HW [0,1] w/o round - img = torch.from_numpy(img) - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(2) - - in_H, in_W, in_C = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) - img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:sym_len_Hs, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[-sym_len_He:, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(out_H, in_W, in_C) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) - out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :sym_len_Ws, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, -sym_len_We:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(out_H, out_W, in_C) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - - return out_2.numpy() - - -if __name__ == '__main__': - print('---') -# img = imread_uint('test.bmp', 3) -# img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/Control-Color/ldm/modules/losses/__init__.py b/Control-Color/ldm/modules/losses/__init__.py deleted file mode 100644 index 62fca5bce4c771b1e06c1cc0c6492842565a7817..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/losses/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -from ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator \ No newline at end of file diff --git a/Control-Color/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc b/Control-Color/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 
index cc91d1b40d078546b88b6521bd53ffb7b5ed8b9b..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc b/Control-Color/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc deleted file mode 100644 index e75f5573d668c9cfc4ecee4505148dd79f90349b..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc b/Control-Color/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc deleted file mode 100644 index ac36dfc60f917286f2d2678520cd3670738ffc58..0000000000000000000000000000000000000000 Binary files a/Control-Color/ldm/modules/losses/__pycache__/vqperceptual.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/ldm/modules/losses/contperceptual.py b/Control-Color/ldm/modules/losses/contperceptual.py deleted file mode 100644 index 44b4505c035a33eab035f922934e188818772ebb..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/losses/contperceptual.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn as nn - -from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? - -#https://github.com/IceClear/StableSR/blob/main/ldm/modules/losses/contperceptual.py - -class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): - - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None, return_dic=False): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss = rec_loss / 
torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights*nll_loss - weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) / nll_loss.shape[0] - if self.kl_weight>0: - kl_loss = posteriors.kl() - kl_loss = torch.mean(kl_loss) / kl_loss.shape[0] - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - # assert not self.training - d_weight = torch.tensor(1.0) * self.discriminator_weight - else: - # d_weight = torch.tensor(0.0) - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - if self.kl_weight>0: - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if return_dic: - loss_dic = {} - loss_dic['total_loss'] = loss.clone().detach().mean() - loss_dic['logvar'] = self.logvar.detach() - loss_dic['kl_loss'] = kl_loss.detach().mean() - loss_dic['nll_loss'] = nll_loss.detach().mean() - loss_dic['rec_loss'] = rec_loss.detach().mean() - loss_dic['d_weight'] = d_weight.detach() - loss_dic['disc_factor'] = torch.tensor(disc_factor) - loss_dic['g_loss'] = g_loss.detach().mean() - else: - loss = weighted_nll_loss + d_weight * disc_factor * g_loss - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if return_dic: - loss_dic = {} - loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean() - loss_dic["{}/logvar".format(split)] = self.logvar.detach() - loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean() - loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean() - loss_dic['d_weight'.format(split)] = d_weight.detach() - loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor) - loss_dic['g_loss'.format(split)] = g_loss.detach().mean() - - if return_dic: - return loss, log, loss_dic - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), 
cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - - if return_dic: - loss_dic = {} - loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean() - loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean() - loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean() - return d_loss, log, loss_dic - - return d_loss, log \ No newline at end of file diff --git a/Control-Color/ldm/modules/losses/vqperceptual.py b/Control-Color/ldm/modules/losses/vqperceptual.py deleted file mode 100644 index 66306c0cf02bddea1cd7960c8ce29a2c7da7ea39..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/losses/vqperceptual.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from taming.modules.losses.lpips import LPIPS -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init - - -class DummyLoss(nn.Module): - def __init__(self): - super().__init__() - - -def adopt_weight(weight, global_step, threshold=0, value=0.): - if global_step < threshold: - weight = value - return weight - - -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1. - logits_real)) - loss_fake = torch.mean(F.relu(1. + logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + - torch.mean(torch.nn.functional.softplus(logits_fake))) - return d_loss - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge"): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * 
self.discriminator_weight - return d_weight - - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train"): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log \ No newline at end of file diff --git a/Control-Color/ldm/modules/midas/__init__.py b/Control-Color/ldm/modules/midas/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/modules/midas/api.py b/Control-Color/ldm/modules/midas/api.py deleted file mode 100644 index b58ebbffd942a2fc22264f0ab47e400c26b9f41c..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/api.py +++ /dev/null @@ -1,170 +0,0 @@ -# based on https://github.com/isl-org/MiDaS - -import cv2 -import torch -import torch.nn as nn -from torchvision.transforms import Compose - -from ldm.modules.midas.midas.dpt_depth import DPTDepthModel -from ldm.modules.midas.midas.midas_net import MidasNet -from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, 
PrepareForNet - - -ISL_PATHS = { - "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", - "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", - "midas_v21": "", - "midas_v21_small": "", -} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def load_midas_transform(model_type): - # https://github.com/isl-org/MiDaS/blob/master/run.py - # load transform only - if model_type == "dpt_large": # DPT-Large - net_w, net_h = 384, 384 - resize_mode = "minimal" - normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - - elif model_type == "dpt_hybrid": # DPT-Hybrid - net_w, net_h = 384, 384 - resize_mode = "minimal" - normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - - elif model_type == "midas_v21": - net_w, net_h = 384, 384 - resize_mode = "upper_bound" - normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - - elif model_type == "midas_v21_small": - net_w, net_h = 256, 256 - resize_mode = "upper_bound" - normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - - else: - assert False, f"model_type '{model_type}' not implemented, use: --model_type large" - - transform = Compose( - [ - Resize( - net_w, - net_h, - resize_target=None, - keep_aspect_ratio=True, - ensure_multiple_of=32, - resize_method=resize_mode, - image_interpolation_method=cv2.INTER_CUBIC, - ), - normalization, - PrepareForNet(), - ] - ) - - return transform - - -def load_model(model_type): - # https://github.com/isl-org/MiDaS/blob/master/run.py - # load network - model_path = ISL_PATHS[model_type] - if model_type == "dpt_large": # DPT-Large - model = DPTDepthModel( - path=model_path, - backbone="vitl16_384", - non_negative=True, - ) - net_w, net_h = 384, 384 - resize_mode = "minimal" - normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - - elif model_type == "dpt_hybrid": # DPT-Hybrid - model = DPTDepthModel( - path=model_path, - backbone="vitb_rn50_384", - non_negative=True, - ) - net_w, net_h = 384, 384 - resize_mode = "minimal" - normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - - elif model_type == "midas_v21": - model = MidasNet(model_path, non_negative=True) - net_w, net_h = 384, 384 - resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) - - elif model_type == "midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) - net_w, net_h = 256, 256 - resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) - - else: - print(f"model_type '{model_type}' not implemented, use: --model_type large") - assert False - - transform = Compose( - [ - Resize( - net_w, - net_h, - resize_target=None, - keep_aspect_ratio=True, - ensure_multiple_of=32, - resize_method=resize_mode, - image_interpolation_method=cv2.INTER_CUBIC, - ), - normalization, - PrepareForNet(), - ] - ) - - return model.eval(), transform - - -class MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] - MODEL_TYPES_ISL = [ - "dpt_large", - "dpt_hybrid", - "midas_v21", - "midas_v21_small", - ] - - def __init__(self, model_type): - super().__init__() - assert (model_type in self.MODEL_TYPES_ISL) - model, _ = 
load_model(model_type) - self.model = model - self.model.train = disabled_train - - def forward(self, x): - # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array - # NOTE: we expect that the correct transform has been called during dataloading. - with torch.no_grad(): - prediction = self.model(x) - prediction = torch.nn.functional.interpolate( - prediction.unsqueeze(1), - size=x.shape[2:], - mode="bicubic", - align_corners=False, - ) - assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) - return prediction - diff --git a/Control-Color/ldm/modules/midas/midas/__init__.py b/Control-Color/ldm/modules/midas/midas/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/Control-Color/ldm/modules/midas/midas/base_model.py b/Control-Color/ldm/modules/midas/midas/base_model.py deleted file mode 100644 index 5cf430239b47ec5ec07531263f26f5c24a2311cd..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/base_model.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - - -class BaseModel(torch.nn.Module): - def load(self, path): - """Load model from file. - - Args: - path (str): file path - """ - parameters = torch.load(path, map_location=torch.device('cpu')) - - if "optimizer" in parameters: - parameters = parameters["model"] - - self.load_state_dict(parameters) diff --git a/Control-Color/ldm/modules/midas/midas/blocks.py b/Control-Color/ldm/modules/midas/midas/blocks.py deleted file mode 100644 index 2145d18fa98060a618536d9a64fe6589e9be4f78..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/blocks.py +++ /dev/null @@ -1,342 +0,0 @@ -import torch -import torch.nn as nn - -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): - if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) - scratch = _make_scratch( - [256, 512, 1024, 1024], features, groups=groups, expand=expand - ) # ViT-L/16 - 85.0% Top1 (backbone) - elif backbone == "vitb_rn50_384": - pretrained = _make_pretrained_vitb_rn50_384( - use_pretrained, - hooks=hooks, - use_vit_only=use_vit_only, - use_readout=use_readout, - ) - scratch = _make_scratch( - [256, 512, 768, 768], features, groups=groups, expand=expand - ) # ViT-H/16 - 85.0% Top1 (backbone) - elif backbone == "vitb16_384": - pretrained = _make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) - scratch = _make_scratch( - [96, 192, 384, 768], features, groups=groups, expand=expand - ) # ViT-B/16 - 84.6% Top1 (backbone) - elif backbone == "resnext101_wsl": - pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 - elif backbone == "efficientnet_lite3": - pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 - else: - print(f"Backbone '{backbone}' not implemented") - assert False - - return pretrained, scratch - - -def _make_scratch(in_shape, out_shape, groups=1, expand=False): - scratch = nn.Module() - - out_shape1 
= out_shape - out_shape2 = out_shape - out_shape3 = out_shape - out_shape4 = out_shape - if expand==True: - out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 - - scratch.layer1_rn = nn.Conv2d( - in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer2_rn = nn.Conv2d( - in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer3_rn = nn.Conv2d( - in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer4_rn = nn.Conv2d( - in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - - return scratch - - -def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): - efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable - ) - return _make_efficientnet_backbone(efficientnet) - - -def _make_efficientnet_backbone(effnet): - pretrained = nn.Module() - - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) - pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) - pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) - pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) - - return pretrained - - -def _make_resnet_backbone(resnet): - pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) - - pretrained.layer2 = resnet.layer2 - pretrained.layer3 = resnet.layer3 - pretrained.layer4 = resnet.layer4 - - return pretrained - - -def _make_pretrained_resnext101_wsl(use_pretrained): - resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") - return _make_resnet_backbone(resnet) - - - -class Interpolate(nn.Module): - """Interpolation module. - """ - - def __init__(self, scale_factor, mode, align_corners=False): - """Init. - - Args: - scale_factor (float): scaling - mode (str): interpolation mode - """ - super(Interpolate, self).__init__() - - self.interp = nn.functional.interpolate - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input - - Returns: - tensor: interpolated data - """ - - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) - - return x - - -class ResidualConvUnit(nn.Module): - """Residual convolution module. - """ - - def __init__(self, features): - """Init. - - Args: - features (int): number of features - """ - super().__init__() - - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) - - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input - - Returns: - tensor: output - """ - out = self.relu(x) - out = self.conv1(out) - out = self.relu(out) - out = self.conv2(out) - - return out + x - - -class FeatureFusionBlock(nn.Module): - """Feature fusion block. - """ - - def __init__(self, features): - """Init. 
- - Args: - features (int): number of features - """ - super(FeatureFusionBlock, self).__init__() - - self.resConfUnit1 = ResidualConvUnit(features) - self.resConfUnit2 = ResidualConvUnit(features) - - def forward(self, *xs): - """Forward pass. - - Returns: - tensor: output - """ - output = xs[0] - - if len(xs) == 2: - output += self.resConfUnit1(xs[1]) - - output = self.resConfUnit2(output) - - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) - - return output - - - - -class ResidualConvUnit_custom(nn.Module): - """Residual convolution module. - """ - - def __init__(self, features, activation, bn): - """Init. - - Args: - features (int): number of features - """ - super().__init__() - - self.bn = bn - - self.groups=1 - - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - if self.bn==True: - self.bn1 = nn.BatchNorm2d(features) - self.bn2 = nn.BatchNorm2d(features) - - self.activation = activation - - self.skip_add = nn.quantized.FloatFunctional() - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input - - Returns: - tensor: output - """ - - out = self.activation(x) - out = self.conv1(out) - if self.bn==True: - out = self.bn1(out) - - out = self.activation(out) - out = self.conv2(out) - if self.bn==True: - out = self.bn2(out) - - if self.groups > 1: - out = self.conv_merge(out) - - return self.skip_add.add(out, x) - - # return out + x - - -class FeatureFusionBlock_custom(nn.Module): - """Feature fusion block. - """ - - def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): - """Init. - - Args: - features (int): number of features - """ - super(FeatureFusionBlock_custom, self).__init__() - - self.deconv = deconv - self.align_corners = align_corners - - self.groups=1 - - self.expand = expand - out_features = features - if self.expand==True: - out_features = features//2 - - self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) - - self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) - self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - - self.skip_add = nn.quantized.FloatFunctional() - - def forward(self, *xs): - """Forward pass. 
- - Returns: - tensor: output - """ - output = xs[0] - - if len(xs) == 2: - res = self.resConfUnit1(xs[1]) - output = self.skip_add.add(output, res) - # output += res - - output = self.resConfUnit2(output) - - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) - - output = self.out_conv(output) - - return output - diff --git a/Control-Color/ldm/modules/midas/midas/dpt_depth.py b/Control-Color/ldm/modules/midas/midas/dpt_depth.py deleted file mode 100644 index 4e9aab5d2767dffea39da5b3f30e2798688216f1..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/dpt_depth.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) - - -def _make_fusion_block(features, use_bn): - return FeatureFusionBlock_custom( - features, - nn.ReLU(False), - deconv=False, - bn=use_bn, - expand=False, - align_corners=True, - ) - - -class DPT(BaseModel): - def __init__( - self, - head, - features=256, - backbone="vitb_rn50_384", - readout="project", - channels_last=False, - use_bn=False, - ): - - super(DPT, self).__init__() - - self.channels_last = channels_last - - hooks = { - "vitb_rn50_384": [0, 1, 8, 11], - "vitb16_384": [2, 5, 8, 11], - "vitl16_384": [5, 11, 17, 23], - } - - # Instantiate backbone and reassemble blocks - self.pretrained, self.scratch = _make_encoder( - backbone, - features, - False, # Set to true of you want to train from scratch, uses ImageNet weights - groups=1, - expand=False, - exportable=False, - hooks=hooks[backbone], - use_readout=readout, - ) - - self.scratch.refinenet1 = _make_fusion_block(features, use_bn) - self.scratch.refinenet2 = _make_fusion_block(features, use_bn) - self.scratch.refinenet3 = _make_fusion_block(features, use_bn) - self.scratch.refinenet4 = _make_fusion_block(features, use_bn) - - self.scratch.output_conv = head - - - def forward(self, x): - if self.channels_last == True: - x.contiguous(memory_format=torch.channels_last) - - layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - path_4 = self.scratch.refinenet4(layer_4_rn) - path_3 = self.scratch.refinenet3(path_4, layer_3_rn) - path_2 = self.scratch.refinenet2(path_3, layer_2_rn) - path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - - out = self.scratch.output_conv(path_1) - - return out - - -class DPTDepthModel(DPT): - def __init__(self, path=None, non_negative=True, **kwargs): - features = kwargs["features"] if "features" in kwargs else 256 - - head = nn.Sequential( - nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), - Interpolate(scale_factor=2, mode="bilinear", align_corners=True), - nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), - nn.ReLU(True) if non_negative else nn.Identity(), - nn.Identity(), - ) - - super().__init__(head, **kwargs) - - if path is not None: - self.load(path) - - def forward(self, x): - return super().forward(x).squeeze(dim=1) - diff --git a/Control-Color/ldm/modules/midas/midas/midas_net.py b/Control-Color/ldm/modules/midas/midas/midas_net.py deleted file mode 100644 
index 8a954977800b0a0f48807e80fa63041910e33c1f..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/midas_net.py +++ /dev/null @@ -1,76 +0,0 @@ -"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. -This file contains code that is adapted from -https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py -""" -import torch -import torch.nn as nn - -from .base_model import BaseModel -from .blocks import FeatureFusionBlock, Interpolate, _make_encoder - - -class MidasNet(BaseModel): - """Network for monocular depth estimation. - """ - - def __init__(self, path=None, features=256, non_negative=True): - """Init. - - Args: - path (str, optional): Path to saved model. Defaults to None. - features (int, optional): Number of features. Defaults to 256. - backbone (str, optional): Backbone network for encoder. Defaults to resnet50 - """ - print("Loading weights: ", path) - - super(MidasNet, self).__init__() - - use_pretrained = False if path is None else True - - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) - - self.scratch.refinenet4 = FeatureFusionBlock(features) - self.scratch.refinenet3 = FeatureFusionBlock(features) - self.scratch.refinenet2 = FeatureFusionBlock(features) - self.scratch.refinenet1 = FeatureFusionBlock(features) - - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), - Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), - nn.ReLU(True) if non_negative else nn.Identity(), - ) - - if path: - self.load(path) - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input data (image) - - Returns: - tensor: depth - """ - - layer_1 = self.pretrained.layer1(x) - layer_2 = self.pretrained.layer2(layer_1) - layer_3 = self.pretrained.layer3(layer_2) - layer_4 = self.pretrained.layer4(layer_3) - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - path_4 = self.scratch.refinenet4(layer_4_rn) - path_3 = self.scratch.refinenet3(path_4, layer_3_rn) - path_2 = self.scratch.refinenet2(path_3, layer_2_rn) - path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - - out = self.scratch.output_conv(path_1) - - return torch.squeeze(out, dim=1) diff --git a/Control-Color/ldm/modules/midas/midas/midas_net_custom.py b/Control-Color/ldm/modules/midas/midas/midas_net_custom.py deleted file mode 100644 index 50e4acb5e53d5fabefe3dde16ab49c33c2b7797c..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/midas_net_custom.py +++ /dev/null @@ -1,128 +0,0 @@ -"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. -This file contains code that is adapted from -https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py -""" -import torch -import torch.nn as nn - -from .base_model import BaseModel -from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder - - -class MidasNet_small(BaseModel): - """Network for monocular depth estimation. 
- """ - - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): - """Init. - - Args: - path (str, optional): Path to saved model. Defaults to None. - features (int, optional): Number of features. Defaults to 256. - backbone (str, optional): Backbone network for encoder. Defaults to resnet50 - """ - print("Loading weights: ", path) - - super(MidasNet_small, self).__init__() - - use_pretrained = False if path else True - - self.channels_last = channels_last - self.blocks = blocks - self.backbone = backbone - - self.groups = 1 - - features1=features - features2=features - features3=features - features4=features - self.expand = False - if "expand" in self.blocks and self.blocks['expand'] == True: - self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 - - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) - - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) - - - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), - Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), - self.scratch.activation, - nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), - nn.ReLU(True) if non_negative else nn.Identity(), - nn.Identity(), - ) - - if path: - self.load(path) - - - def forward(self, x): - """Forward pass. 
- - Args: - x (tensor): input data (image) - - Returns: - tensor: depth - """ - if self.channels_last==True: - print("self.channels_last = ", self.channels_last) - x.contiguous(memory_format=torch.channels_last) - - - layer_1 = self.pretrained.layer1(x) - layer_2 = self.pretrained.layer2(layer_1) - layer_3 = self.pretrained.layer3(layer_2) - layer_4 = self.pretrained.layer4(layer_3) - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - - path_4 = self.scratch.refinenet4(layer_4_rn) - path_3 = self.scratch.refinenet3(path_4, layer_3_rn) - path_2 = self.scratch.refinenet2(path_3, layer_2_rn) - path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - - out = self.scratch.output_conv(path_1) - - return torch.squeeze(out, dim=1) - - - -def fuse_model(m): - prev_previous_type = nn.Identity() - prev_previous_name = '' - previous_type = nn.Identity() - previous_name = '' - for name, module in m.named_modules(): - if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: - # print("FUSED ", prev_previous_name, previous_name, name) - torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) - elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: - # print("FUSED ", prev_previous_name, previous_name) - torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) - # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: - # print("FUSED ", previous_name, name) - # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) - - prev_previous_type = previous_type - prev_previous_name = previous_name - previous_type = type(module) - previous_name = name \ No newline at end of file diff --git a/Control-Color/ldm/modules/midas/midas/transforms.py b/Control-Color/ldm/modules/midas/midas/transforms.py deleted file mode 100644 index 350cbc11662633ad7f8968eb10be2e7de6e384e9..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/transforms.py +++ /dev/null @@ -1,234 +0,0 @@ -import numpy as np -import cv2 -import math - - -def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): - """Rezise the sample to ensure the given size. Keeps aspect ratio. - - Args: - sample (dict): sample - size (tuple): image size - - Returns: - tuple: new size - """ - shape = list(sample["disparity"].shape) - - if shape[0] >= size[0] and shape[1] >= size[1]: - return sample - - scale = [0, 0] - scale[0] = size[0] / shape[0] - scale[1] = size[1] / shape[1] - - scale = max(scale) - - shape[0] = math.ceil(scale * shape[0]) - shape[1] = math.ceil(scale * shape[1]) - - # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) - - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) - sample["mask"] = cv2.resize( - sample["mask"].astype(np.float32), - tuple(shape[::-1]), - interpolation=cv2.INTER_NEAREST, - ) - sample["mask"] = sample["mask"].astype(bool) - - return tuple(shape) - - -class Resize(object): - """Resize sample to given size (width, height). - """ - - def __init__( - self, - width, - height, - resize_target=True, - keep_aspect_ratio=False, - ensure_multiple_of=1, - resize_method="lower_bound", - image_interpolation_method=cv2.INTER_AREA, - ): - """Init. 
- - Args: - width (int): desired output width - height (int): desired output height - resize_target (bool, optional): - True: Resize the full sample (image, mask, target). - False: Resize image only. - Defaults to True. - keep_aspect_ratio (bool, optional): - True: Keep the aspect ratio of the input sample. - Output sample might not have the given width and height, and - resize behaviour depends on the parameter 'resize_method'. - Defaults to False. - ensure_multiple_of (int, optional): - Output width and height is constrained to be multiple of this parameter. - Defaults to 1. - resize_method (str, optional): - "lower_bound": Output will be at least as large as the given size. - "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) - "minimal": Scale as least as possible. (Output size might be smaller than given size.) - Defaults to "lower_bound". - """ - self.__width = width - self.__height = height - - self.__resize_target = resize_target - self.__keep_aspect_ratio = keep_aspect_ratio - self.__multiple_of = ensure_multiple_of - self.__resize_method = resize_method - self.__image_interpolation_method = image_interpolation_method - - def constrain_to_multiple_of(self, x, min_val=0, max_val=None): - y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) - - if max_val is not None and y > max_val: - y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) - - if y < min_val: - y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) - - return y - - def get_size(self, width, height): - # determine new height and width - scale_height = self.__height / height - scale_width = self.__width / width - - if self.__keep_aspect_ratio: - if self.__resize_method == "lower_bound": - # scale such that output size is lower bound - if scale_width > scale_height: - # fit width - scale_height = scale_width - else: - # fit height - scale_width = scale_height - elif self.__resize_method == "upper_bound": - # scale such that output size is upper bound - if scale_width < scale_height: - # fit width - scale_height = scale_width - else: - # fit height - scale_width = scale_height - elif self.__resize_method == "minimal": - # scale as least as possbile - if abs(1 - scale_width) < abs(1 - scale_height): - # fit width - scale_height = scale_width - else: - # fit height - scale_width = scale_height - else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) - - if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) - elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) - elif self.__resize_method == "minimal": - new_height = self.constrain_to_multiple_of(scale_height * height) - new_width = self.constrain_to_multiple_of(scale_width * width) - else: - raise ValueError(f"resize_method {self.__resize_method} not implemented") - - return (new_width, new_height) - - def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) - - # resize sample - sample["image"] = cv2.resize( - sample["image"], - (width, height), - interpolation=self.__image_interpolation_method, - ) - - if 
self.__resize_target: - if "disparity" in sample: - sample["disparity"] = cv2.resize( - sample["disparity"], - (width, height), - interpolation=cv2.INTER_NEAREST, - ) - - if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) - - sample["mask"] = cv2.resize( - sample["mask"].astype(np.float32), - (width, height), - interpolation=cv2.INTER_NEAREST, - ) - sample["mask"] = sample["mask"].astype(bool) - - return sample - - -class NormalizeImage(object): - """Normlize image by given mean and std. - """ - - def __init__(self, mean, std): - self.__mean = mean - self.__std = std - - def __call__(self, sample): - sample["image"] = (sample["image"] - self.__mean) / self.__std - - return sample - - -class PrepareForNet(object): - """Prepare sample for usage as network input. - """ - - def __init__(self): - pass - - def __call__(self, sample): - image = np.transpose(sample["image"], (2, 0, 1)) - sample["image"] = np.ascontiguousarray(image).astype(np.float32) - - if "mask" in sample: - sample["mask"] = sample["mask"].astype(np.float32) - sample["mask"] = np.ascontiguousarray(sample["mask"]) - - if "disparity" in sample: - disparity = sample["disparity"].astype(np.float32) - sample["disparity"] = np.ascontiguousarray(disparity) - - if "depth" in sample: - depth = sample["depth"].astype(np.float32) - sample["depth"] = np.ascontiguousarray(depth) - - return sample diff --git a/Control-Color/ldm/modules/midas/midas/vit.py b/Control-Color/ldm/modules/midas/midas/vit.py deleted file mode 100644 index ea46b1be88b261b0dec04f3da0256f5f66f88a74..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/midas/vit.py +++ /dev/null @@ -1,491 +0,0 @@ -import torch -import torch.nn as nn -import timm -import types -import math -import torch.nn.functional as F - - -class Slice(nn.Module): - def __init__(self, start_index=1): - super(Slice, self).__init__() - self.start_index = start_index - - def forward(self, x): - return x[:, self.start_index :] - - -class AddReadout(nn.Module): - def __init__(self, start_index=1): - super(AddReadout, self).__init__() - self.start_index = start_index - - def forward(self, x): - if self.start_index == 2: - readout = (x[:, 0] + x[:, 1]) / 2 - else: - readout = x[:, 0] - return x[:, self.start_index :] + readout.unsqueeze(1) - - -class ProjectReadout(nn.Module): - def __init__(self, in_features, start_index=1): - super(ProjectReadout, self).__init__() - self.start_index = start_index - - self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) - - def forward(self, x): - readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) - features = torch.cat((x[:, self.start_index :], readout), -1) - - return self.project(features) - - -class Transpose(nn.Module): - def __init__(self, dim0, dim1): - super(Transpose, self).__init__() - self.dim0 = dim0 - self.dim1 = dim1 - - def forward(self, x): - x = x.transpose(self.dim0, self.dim1) - return x - - -def forward_vit(pretrained, x): - b, c, h, w = x.shape - - glob = pretrained.model.forward_flex(x) - - layer_1 = pretrained.activations["1"] - layer_2 = pretrained.activations["2"] - layer_3 = pretrained.activations["3"] - layer_4 = pretrained.activations["4"] - - layer_1 = pretrained.act_postprocess1[0:2](layer_1) - layer_2 = pretrained.act_postprocess2[0:2](layer_2) - layer_3 = pretrained.act_postprocess3[0:2](layer_3) - layer_4 = pretrained.act_postprocess4[0:2](layer_4) - - unflatten = nn.Sequential( - nn.Unflatten( - 
2, - torch.Size( - [ - h // pretrained.model.patch_size[1], - w // pretrained.model.patch_size[0], - ] - ), - ) - ) - - if layer_1.ndim == 3: - layer_1 = unflatten(layer_1) - if layer_2.ndim == 3: - layer_2 = unflatten(layer_2) - if layer_3.ndim == 3: - layer_3 = unflatten(layer_3) - if layer_4.ndim == 3: - layer_4 = unflatten(layer_4) - - layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) - layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) - layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) - layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) - - return layer_1, layer_2, layer_3, layer_4 - - -def _resize_pos_embed(self, posemb, gs_h, gs_w): - posemb_tok, posemb_grid = ( - posemb[:, : self.start_index], - posemb[0, self.start_index :], - ) - - gs_old = int(math.sqrt(len(posemb_grid))) - - posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) - - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) - - return posemb - - -def forward_flex(self, x): - b, c, h, w = x.shape - - pos_embed = self._resize_pos_embed( - self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] - ) - - B = x.shape[0] - - if hasattr(self.patch_embed, "backbone"): - x = self.patch_embed.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - - x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) - - if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks - dist_token = self.dist_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, dist_token, x), dim=1) - else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - - x = x + pos_embed - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - - x = self.norm(x) - - return x - - -activations = {} - - -def get_activation(name): - def hook(model, input, output): - activations[name] = output - - return hook - - -def get_readout_oper(vit_features, features, use_readout, start_index=1): - if use_readout == "ignore": - readout_oper = [Slice(start_index)] * len(features) - elif use_readout == "add": - readout_oper = [AddReadout(start_index)] * len(features) - elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] - else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" - - return readout_oper - - -def _make_vit_b16_backbone( - model, - features=[96, 192, 384, 768], - size=[384, 384], - hooks=[2, 5, 8, 11], - vit_features=768, - use_readout="ignore", - start_index=1, -): - pretrained = nn.Module() - - pretrained.model = model - pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) - pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) - pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) - pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) - - pretrained.activations = activations - - readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) - - # 32, 48, 136, 384 
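A minimal, self-contained sketch of what the `_resize_pos_embed` helper defined above does: the learned position embedding is split into its readout token(s) and its patch grid, the grid is bilinearly interpolated to the new token layout, and the two parts are re-concatenated, which is what lets the hooked timm ViT accept inputs whose patch grid differs from the pretraining resolution. The grid sizes and embedding width below are illustrative assumptions, not values taken from this repository.

    import math
    import torch
    import torch.nn.functional as F

    posemb = torch.randn(1, 1 + 24 * 24, 768)   # readout ([CLS]) token + 24x24 patch grid (assumed shapes)
    start_index = 1
    gs_h, gs_w = 30, 40                          # target token grid for a larger input image

    posemb_tok = posemb[:, :start_index]         # readout token(s) pass through unchanged
    posemb_grid = posemb[0, start_index:]
    gs_old = int(math.sqrt(len(posemb_grid)))    # original grid side length (24 here)

    # reshape to (1, C, H, W), resize with bilinear interpolation, flatten back to a token sequence
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    resized = torch.cat([posemb_tok, posemb_grid], dim=1)   # shape (1, 1 + 30*40, 768)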
- pretrained.act_postprocess1 = nn.Sequential( - readout_oper[0], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[0], - kernel_size=1, - stride=1, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=features[0], - out_channels=features[0], - kernel_size=4, - stride=4, - padding=0, - bias=True, - dilation=1, - groups=1, - ), - ) - - pretrained.act_postprocess2 = nn.Sequential( - readout_oper[1], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[1], - kernel_size=1, - stride=1, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=features[1], - out_channels=features[1], - kernel_size=2, - stride=2, - padding=0, - bias=True, - dilation=1, - groups=1, - ), - ) - - pretrained.act_postprocess3 = nn.Sequential( - readout_oper[2], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[2], - kernel_size=1, - stride=1, - padding=0, - ), - ) - - pretrained.act_postprocess4 = nn.Sequential( - readout_oper[3], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[3], - kernel_size=1, - stride=1, - padding=0, - ), - nn.Conv2d( - in_channels=features[3], - out_channels=features[3], - kernel_size=3, - stride=2, - padding=1, - ), - ) - - pretrained.model.start_index = start_index - pretrained.model.patch_size = [16, 16] - - # We inject this function into the VisionTransformer instances so that - # we can use it with interpolated position embeddings without modifying the library source. - pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) - - return pretrained - - -def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) - - hooks = [5, 11, 17, 23] if hooks == None else hooks - return _make_vit_b16_backbone( - model, - features=[256, 512, 1024, 1024], - hooks=hooks, - vit_features=1024, - use_readout=use_readout, - ) - - -def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) - - hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) - - -def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) - - hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) - - -def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) - - hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, - features=[96, 192, 384, 768], - hooks=hooks, - use_readout=use_readout, - start_index=2, - ) - - -def _make_vit_b_rn50_backbone( - model, - features=[256, 512, 768, 768], - size=[384, 384], - hooks=[0, 1, 8, 11], - vit_features=768, - use_vit_only=False, - 
use_readout="ignore", - start_index=1, -): - pretrained = nn.Module() - - pretrained.model = model - - if use_vit_only == True: - pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) - pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) - else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) - - pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) - pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) - - pretrained.activations = activations - - readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) - - if use_vit_only == True: - pretrained.act_postprocess1 = nn.Sequential( - readout_oper[0], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[0], - kernel_size=1, - stride=1, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=features[0], - out_channels=features[0], - kernel_size=4, - stride=4, - padding=0, - bias=True, - dilation=1, - groups=1, - ), - ) - - pretrained.act_postprocess2 = nn.Sequential( - readout_oper[1], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[1], - kernel_size=1, - stride=1, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=features[1], - out_channels=features[1], - kernel_size=2, - stride=2, - padding=0, - bias=True, - dilation=1, - groups=1, - ), - ) - else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - - pretrained.act_postprocess3 = nn.Sequential( - readout_oper[2], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[2], - kernel_size=1, - stride=1, - padding=0, - ), - ) - - pretrained.act_postprocess4 = nn.Sequential( - readout_oper[3], - Transpose(1, 2), - nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), - nn.Conv2d( - in_channels=vit_features, - out_channels=features[3], - kernel_size=1, - stride=1, - padding=0, - ), - nn.Conv2d( - in_channels=features[3], - out_channels=features[3], - kernel_size=3, - stride=2, - padding=1, - ), - ) - - pretrained.model.start_index = start_index - pretrained.model.patch_size = [16, 16] - - # We inject this function into the VisionTransformer instances so that - # we can use it with interpolated position embeddings without modifying the library source. - pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - - # We inject this function into the VisionTransformer instances so that - # we can use it with interpolated position embeddings without modifying the library source. 
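The comment above refers to per-instance monkey patching: `forward_flex` (and, just below, `_resize_pos_embed`) are plain functions that get bound onto the already-constructed timm VisionTransformer object with `types.MethodType`, so the timm library source itself is never modified. A toy sketch of that binding, using made-up names (`Backbone`, `describe`) purely for illustration:

    import types

    class Backbone:                      # stand-in for the timm VisionTransformer instance
        def __init__(self, name):
            self.name = name

    def describe(self):                  # free function defined outside the class
        return f"backbone: {self.name}"

    b = Backbone("vit_base_resnet50_384")
    b.describe = types.MethodType(describe, b)   # bound to this one instance only
    print(b.describe())                          # backbone: vit_base_resnet50_384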
- pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) - - return pretrained - - -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): - model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) - - hooks = [0, 1, 8, 11] if hooks == None else hooks - return _make_vit_b_rn50_backbone( - model, - features=[256, 512, 768, 768], - size=[384, 384], - hooks=hooks, - use_vit_only=use_vit_only, - use_readout=use_readout, - ) diff --git a/Control-Color/ldm/modules/midas/utils.py b/Control-Color/ldm/modules/midas/utils.py deleted file mode 100644 index 9a9d3b5b66370fa98da9e067ba53ead848ea9a59..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/modules/midas/utils.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Utils for monoDepth.""" -import sys -import re -import numpy as np -import cv2 -import torch - - -def read_pfm(path): - """Read pfm file. - - Args: - path (str): path to file - - Returns: - tuple: (data, scale) - """ - with open(path, "rb") as file: - - color = None - width = None - height = None - scale = None - endian = None - - header = file.readline().rstrip() - if header.decode("ascii") == "PF": - color = True - elif header.decode("ascii") == "Pf": - color = False - else: - raise Exception("Not a PFM file: " + path) - - dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) - if dim_match: - width, height = list(map(int, dim_match.groups())) - else: - raise Exception("Malformed PFM header.") - - scale = float(file.readline().decode("ascii").rstrip()) - if scale < 0: - # little-endian - endian = "<" - scale = -scale - else: - # big-endian - endian = ">" - - data = np.fromfile(file, endian + "f") - shape = (height, width, 3) if color else (height, width) - - data = np.reshape(data, shape) - data = np.flipud(data) - - return data, scale - - -def write_pfm(path, image, scale=1): - """Write pfm file. - - Args: - path (str): pathto file - image (array): data - scale (int, optional): Scale. Defaults to 1. - """ - - with open(path, "wb") as file: - color = None - - if image.dtype.name != "float32": - raise Exception("Image dtype must be float32.") - - image = np.flipud(image) - - if len(image.shape) == 3 and image.shape[2] == 3: # color image - color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale - color = False - else: - raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") - - file.write("PF\n" if color else "Pf\n".encode()) - file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) - - endian = image.dtype.byteorder - - if endian == "<" or endian == "=" and sys.byteorder == "little": - scale = -scale - - file.write("%f\n".encode() % scale) - - image.tofile(file) - - -def read_image(path): - """Read image and output RGB image (0-1). - - Args: - path (str): path to file - - Returns: - array: RGB image (0-1) - """ - img = cv2.imread(path) - - if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 - - return img - - -def resize_image(img): - """Resize image and make it fit for network. 
- - Args: - img (array): image - - Returns: - tensor: data ready for network - """ - height_orig = img.shape[0] - width_orig = img.shape[1] - - if width_orig > height_orig: - scale = width_orig / 384 - else: - scale = height_orig / 384 - - height = (np.ceil(height_orig / scale / 32) * 32).astype(int) - width = (np.ceil(width_orig / scale / 32) * 32).astype(int) - - img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) - img_resized = img_resized.unsqueeze(0) - - return img_resized - - -def resize_depth(depth, width, height): - """Resize depth map and bring to CPU (numpy). - - Args: - depth (tensor): depth - width (int): image width - height (int): image height - - Returns: - array: processed depth - """ - depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - - depth_resized = cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) - - return depth_resized - -def write_depth(path, depth, bits=1): - """Write depth map to pfm and png file. - - Args: - path (str): filepath without extension - depth (array): depth - """ - write_pfm(path + ".pfm", depth.astype(np.float32)) - - depth_min = depth.min() - depth_max = depth.max() - - max_val = (2**(8*bits))-1 - - if depth_max - depth_min > np.finfo("float").eps: - out = max_val * (depth - depth_min) / (depth_max - depth_min) - else: - out = np.zeros(depth.shape, dtype=depth.type) - - if bits == 1: - cv2.imwrite(path + ".png", out.astype("uint8")) - elif bits == 2: - cv2.imwrite(path + ".png", out.astype("uint16")) - - return diff --git a/Control-Color/ldm/util.py b/Control-Color/ldm/util.py deleted file mode 100644 index a69857d0b7c9496a7a1228e5b68b732be93d6cf9..0000000000000000000000000000000000000000 --- a/Control-Color/ldm/util.py +++ /dev/null @@ -1,235 +0,0 @@ -import importlib - -import torch -from torch import optim -import numpy as np - -from inspect import isfunction -from PIL import Image, ImageDraw, ImageFont -from torchvision import transforms -import cv2 - -# def get_hint_image(image_withmask): -# image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. -# image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0] -# image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1) -# mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. -# H,W=mask.shape -# for i in range(H): -# for j in range(W): -# if mask[i,j]==0: -# image[:,i,j]=image_gray[:,i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j] -# return image - -def get_hint_image(image,image_gray,mask): - # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. - # image_gray=cv2.cvtColor(np.asarray(image.permute(1,2,0).cpu()),cv2.COLOR_RGB2LAB)[:,:,0] - # image_gray = torch.from_numpy(cv2.merge([image_gray,image_gray,image_gray])).permute(2,0,1) - # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. - image=np.array(image.copy()) - image_gray=np.array(image_gray.copy()) - H,W=mask.shape - for i in range(H): - for j in range(W): - if mask[i,j]==0: - image[i,j]=image_gray[i,j] #torch.mean(image[:,i,j]) #image_gray[:,i,j] - return Image.fromarray(image) - -def log_txt_as_img(wh,masked_image, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - xc=xc - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - # image=(image_withmask.squeeze(0)[:3,:,:]+1.)/2. - # mask=(image_withmask.squeeze(0)[3,:,:]+1.)/2. - # image=(image_withmask+1.)/2. 
- # # image = get_hint_image(image_withmask) - # # print(image.shape) - # image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB") - # # image_gray=transforms.ToPILImage()(image).convert("L") - image=(masked_image.squeeze(0)+1.)/2. - image_target=transforms.ToPILImage()(image.squeeze(0)).convert("RGB") - txt = image_target#get_hint_image(image_target,image_gray,mask) - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) - nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x,torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if not config == '__is_first_stage__':#changed for only training vae - return None - # elif config == "__is_unconditional__":#changed for only training vae - # return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -class AdamWwithEMAandWings(optim.Optimizer): - # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): - """AdamW that saves EMA versions of the parameters.""" - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0.0 <= ema_decay <= 1.0: - raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, 
param_names=param_names) - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - ema_params_with_grad = [] - state_sums = [] - max_exp_avg_sqs = [] - state_steps = [] - amsgrad = group['amsgrad'] - beta1, beta2 = group['betas'] - ema_decay = group['ema_decay'] - ema_power = group['ema_power'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - # Exponential moving average of parameter values - state['param_exp_avg'] = p.detach().float().clone() - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - ema_params_with_grad.append(state['param_exp_avg']) - - if amsgrad: - max_exp_avg_sqs.append(state['max_exp_avg_sq']) - - # update the steps for each param group update - state['step'] += 1 - # record the step after step update - state_steps.append(state['step']) - - optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) - - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) - for param, ema_param in zip(params_with_grad, ema_params_with_grad): - ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) - - return loss \ No newline at end of file diff --git a/Control-Color/models/cldm_v15_inpainting_infer.yaml b/Control-Color/models/cldm_v15_inpainting_infer.yaml deleted file mode 100644 index ed34bd309b7f632c44e333a2aed90666d32629e2..0000000000000000000000000000000000000000 --- a/Control-Color/models/cldm_v15_inpainting_infer.yaml +++ /dev/null @@ -1,87 +0,0 @@ -model: - target: cldm.cldm.ControlLDM - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - control_key: "hint" - masked_image: "mask_img" - mask: "mask" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - only_mid_control: False - load_loss: False - - control_stage_config: - target: cldm.cldm.ControlNet - params: - image_size: 32 # unused - in_channels: 4 - hint_channels: 3 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 
8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - unet_config: - target: cldm.cldm.ControlledUnetModel - params: - image_size: 32 # unused - in_channels: 9 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - contextual_stage_config: - target: models_deep_exp.NonlocalNet.VGG19_pytorch - - cond_stage_config: - # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder - #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder diff --git a/Control-Color/models/cldm_v15_inpainting_infer1.yaml b/Control-Color/models/cldm_v15_inpainting_infer1.yaml deleted file mode 100644 index 52bc556a3a0797dd1e712c9a586151c7c16a79d6..0000000000000000000000000000000000000000 --- a/Control-Color/models/cldm_v15_inpainting_infer1.yaml +++ /dev/null @@ -1,87 +0,0 @@ -model: - target: cldm.cldm.ControlLDM - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - control_key: "hint" - masked_image: "mask_img" - mask: "mask" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - only_mid_control: False - load_loss: False - - control_stage_config: - target: cldm.cldm.ControlNet - params: - image_size: 32 # unused - in_channels: 4 - hint_channels: 3 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - unet_config: - target: cldm.cldm.ControlledUnetModel - params: - image_size: 32 # unused - in_channels: 9 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - contextual_stage_config: - target: models_deep_exp.NonlocalNet.VGG19_pytorch - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - # target: ldm.modules.encoders.modules.FrozenCLIPDualEmbedder - #ldm.modules.encoders.modules.FrozenCLIPDualEmbedder diff --git a/Control-Color/requirements.txt b/Control-Color/requirements.txt deleted file mode 100644 index 6d3d6f2111c8034831ae399558bd89d6ee1639ed..0000000000000000000000000000000000000000 --- a/Control-Color/requirements.txt +++ /dev/null @@ -1,29 +0,0 @@ 
-gradio -gradio-client -albumentations==1.3.0 -opencv-python==4.9.0.80 -opencv-python-headless==4.5.5.64 -imageio==2.9.0 -imageio-ffmpeg==0.4.2 -pytorch-lightning==1.5.0 -omegaconf==2.1.1 -test-tube>=0.7.5 -streamlit==1.12.1 -webdataset==0.2.5 -kornia==0.6 -open_clip_torch==2.0.2 -invisible-watermark>=0.1.5 -streamlit-drawable-canvas==0.8.0 -torchmetrics==0.6.0 -addict==2.4.0 -yapf==0.32.0 -prettytable==3.6.0 -basicsr==1.4.2 -salesforce-lavis==1.0.2 -grpcio==1.60 -pydantic==1.10.5 -wandb==0.15.12 -spacy==3.5.1 -typer==0.7.0 -typing-extensions==4.4.0 -fastapi==0.92.0 \ No newline at end of file diff --git a/Control-Color/share.py b/Control-Color/share.py deleted file mode 100644 index 463af08fb936d650b5dd2e66183661181c34a3d6..0000000000000000000000000000000000000000 --- a/Control-Color/share.py +++ /dev/null @@ -1,8 +0,0 @@ -import config -from cldm.hack import disable_verbosity, enable_sliced_attention - - -disable_verbosity() - -if config.save_memory: - enable_sliced_attention() diff --git a/Control-Color/taming/__pycache__/util.cpython-38.pyc b/Control-Color/taming/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 8303bb601f2c59bde20062a3d6ea53e16d6f5ca7..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/data/ade20k.py b/Control-Color/taming/data/ade20k.py deleted file mode 100644 index 366dae97207dbb8356598d636e14ad084d45bc76..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/ade20k.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import numpy as np -import cv2 -import albumentations -from PIL import Image -from torch.utils.data import Dataset - -from taming.data.sflckr import SegmentationBase # for examples included in repo - - -class Examples(SegmentationBase): - def __init__(self, size=256, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/ade20k_examples.txt", - data_root="data/ade20k_images", - segmentation_root="data/ade20k_segmentations", - size=size, random_crop=random_crop, - interpolation=interpolation, - n_labels=151, shift_segmentation=False) - - -# With semantic map and scene label -class ADE20kBase(Dataset): - def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): - self.split = self.get_split() - self.n_labels = 151 # unknown + 150 - self.data_csv = {"train": "data/ade20k_train.txt", - "validation": "data/ade20k_test.txt"}[self.split] - self.data_root = "data/ade20k_root" - with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: - self.scene_categories = f.read().splitlines() - self.scene_categories = dict(line.split() for line in self.scene_categories) - with open(self.data_csv, "r") as f: - self.image_paths = f.read().splitlines() - self._length = len(self.image_paths) - self.labels = { - "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, "images", l) - for l in self.image_paths], - "relative_segmentation_path_": [l.replace(".jpg", ".png") - for l in self.image_paths], - "segmentation_path_": [os.path.join(self.data_root, "annotations", - l.replace(".jpg", ".png")) - for l in self.image_paths], - "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] - for l in self.image_paths], - } - - size = None if size is not None and size<=0 else size - self.size = size - if crop_size is None: - self.crop_size = size if size is not None else None - else: - self.crop_size 
= crop_size - if self.size is not None: - self.interpolation = interpolation - self.interpolation = { - "nearest": cv2.INTER_NEAREST, - "bilinear": cv2.INTER_LINEAR, - "bicubic": cv2.INTER_CUBIC, - "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] - self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=self.interpolation) - self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=cv2.INTER_NEAREST) - - if crop_size is not None: - self.center_crop = not random_crop - if self.center_crop: - self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) - else: - self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) - self.preprocessor = self.cropper - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = dict((k, self.labels[k][i]) for k in self.labels) - image = Image.open(example["file_path_"]) - if not image.mode == "RGB": - image = image.convert("RGB") - image = np.array(image).astype(np.uint8) - if self.size is not None: - image = self.image_rescaler(image=image)["image"] - segmentation = Image.open(example["segmentation_path_"]) - segmentation = np.array(segmentation).astype(np.uint8) - if self.size is not None: - segmentation = self.segmentation_rescaler(image=segmentation)["image"] - if self.size is not None: - processed = self.preprocessor(image=image, mask=segmentation) - else: - processed = {"image": image, "mask": segmentation} - example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) - segmentation = processed["mask"] - onehot = np.eye(self.n_labels)[segmentation] - example["segmentation"] = onehot - return example - - -class ADE20kTrain(ADE20kBase): - # default to random_crop=True - def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): - super().__init__(config=config, size=size, random_crop=random_crop, - interpolation=interpolation, crop_size=crop_size) - - def get_split(self): - return "train" - - -class ADE20kValidation(ADE20kBase): - def get_split(self): - return "validation" - - -if __name__ == "__main__": - dset = ADE20kValidation() - ex = dset[0] - for k in ["image", "scene_category", "segmentation"]: - print(type(ex[k])) - try: - print(ex[k].shape) - except: - print(ex[k]) diff --git a/Control-Color/taming/data/annotated_objects_coco.py b/Control-Color/taming/data/annotated_objects_coco.py deleted file mode 100644 index af000ecd943d7b8a85d7eb70195c9ecd10ab5edc..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/annotated_objects_coco.py +++ /dev/null @@ -1,139 +0,0 @@ -import json -from itertools import chain -from pathlib import Path -from typing import Iterable, Dict, List, Callable, Any -from collections import defaultdict - -from tqdm import tqdm - -from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset -from taming.data.helper_types import Annotation, ImageDescription, Category - -COCO_PATH_STRUCTURE = { - 'train': { - 'top_level': '', - 'instances_annotations': 'annotations/instances_train2017.json', - 'stuff_annotations': 'annotations/stuff_train2017.json', - 'files': 'train2017' - }, - 'validation': { - 'top_level': '', - 'instances_annotations': 'annotations/instances_val2017.json', - 'stuff_annotations': 'annotations/stuff_val2017.json', - 'files': 'val2017' - } -} - - -def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: - return { - 
str(img['id']): ImageDescription( - id=img['id'], - license=img.get('license'), - file_name=img['file_name'], - coco_url=img['coco_url'], - original_size=(img['width'], img['height']), - date_captured=img.get('date_captured'), - flickr_url=img.get('flickr_url') - ) - for img in description_json - } - - -def load_categories(category_json: Iterable) -> Dict[str, Category]: - return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) - for cat in category_json if cat['name'] != 'other'} - - -def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], - category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: - annotations = defaultdict(list) - total = sum(len(a) for a in annotations_json) - for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): - image_id = str(ann['image_id']) - if image_id not in image_descriptions: - raise ValueError(f'image_id [{image_id}] has no image description.') - category_id = ann['category_id'] - try: - category_no = category_no_for_id(str(category_id)) - except KeyError: - continue - - width, height = image_descriptions[image_id].original_size - bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) - - annotations[image_id].append( - Annotation( - id=ann['id'], - area=bbox[2]*bbox[3], # use bbox area - is_group_of=ann['iscrowd'], - image_id=ann['image_id'], - bbox=bbox, - category_id=str(category_id), - category_no=category_no - ) - ) - return dict(annotations) - - -class AnnotatedObjectsCoco(AnnotatedObjectsDataset): - def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): - """ - @param data_path: is the path to the following folder structure: - coco/ - ├── annotations - │ ├── instances_train2017.json - │ ├── instances_val2017.json - │ ├── stuff_train2017.json - │ └── stuff_val2017.json - ├── train2017 - │ ├── 000000000009.jpg - │ ├── 000000000025.jpg - │ └── ... - ├── val2017 - │ ├── 000000000139.jpg - │ ├── 000000000285.jpg - │ └── ... 
- @param: split: one of 'train' or 'validation' - @param: desired image size (give square images) - """ - super().__init__(**kwargs) - self.use_things = use_things - self.use_stuff = use_stuff - - with open(self.paths['instances_annotations']) as f: - inst_data_json = json.load(f) - with open(self.paths['stuff_annotations']) as f: - stuff_data_json = json.load(f) - - category_jsons = [] - annotation_jsons = [] - if self.use_things: - category_jsons.append(inst_data_json['categories']) - annotation_jsons.append(inst_data_json['annotations']) - if self.use_stuff: - category_jsons.append(stuff_data_json['categories']) - annotation_jsons.append(stuff_data_json['annotations']) - - self.categories = load_categories(chain(*category_jsons)) - self.filter_categories() - self.setup_category_id_and_number() - - self.image_descriptions = load_image_descriptions(inst_data_json['images']) - annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) - self.annotations = self.filter_object_number(annotations, self.min_object_area, - self.min_objects_per_image, self.max_objects_per_image) - self.image_ids = list(self.annotations.keys()) - self.clean_up_annotations_and_image_descriptions() - - def get_path_structure(self) -> Dict[str, str]: - if self.split not in COCO_PATH_STRUCTURE: - raise ValueError(f'Split [{self.split} does not exist for COCO data.]') - return COCO_PATH_STRUCTURE[self.split] - - def get_image_path(self, image_id: str) -> Path: - return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) - - def get_image_description(self, image_id: str) -> Dict[str, Any]: - # noinspection PyProtectedMember - return self.image_descriptions[image_id]._asdict() diff --git a/Control-Color/taming/data/annotated_objects_dataset.py b/Control-Color/taming/data/annotated_objects_dataset.py deleted file mode 100644 index 53cc346a1c76289a4964d7dc8a29582172f33dc0..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/annotated_objects_dataset.py +++ /dev/null @@ -1,218 +0,0 @@ -from pathlib import Path -from typing import Optional, List, Callable, Dict, Any, Union -import warnings - -import PIL.Image as pil_image -from torch import Tensor -from torch.utils.data import Dataset -from torchvision import transforms - -from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder -from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder -from taming.data.conditional_builder.utils import load_object_from_string -from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType -from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ - Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor - - -class AnnotatedObjectsDataset(Dataset): - def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, - min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, - crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, - encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", - no_object_classes: Optional[int] = None): - self.data_path = data_path - self.split = split - self.keys = keys - self.target_image_size = target_image_size - self.min_object_area = min_object_area - self.min_objects_per_image = 
min_objects_per_image - self.max_objects_per_image = max_objects_per_image - self.crop_method = crop_method - self.random_flip = random_flip - self.no_tokens = no_tokens - self.use_group_parameter = use_group_parameter - self.encode_crop = encode_crop - - self.annotations = None - self.image_descriptions = None - self.categories = None - self.category_ids = None - self.category_number = None - self.image_ids = None - self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip) - self.paths = self.build_paths(self.data_path) - self._conditional_builders = None - self.category_allow_list = None - if category_allow_list_target: - allow_list = load_object_from_string(category_allow_list_target) - self.category_allow_list = {name for name, _ in allow_list} - self.category_mapping = {} - if category_mapping_target: - self.category_mapping = load_object_from_string(category_mapping_target) - self.no_object_classes = no_object_classes - - def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: - top_level = Path(top_level) - sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} - for path in sub_paths.values(): - if not path.exists(): - raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') - return sub_paths - - @staticmethod - def load_image_from_disk(path: Path) -> Image: - return pil_image.open(path).convert('RGB') - - @staticmethod - def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): - transform_functions = [] - if crop_method == 'none': - transform_functions.append(transforms.Resize((target_image_size, target_image_size))) - elif crop_method == 'center': - transform_functions.extend([ - transforms.Resize(target_image_size), - CenterCropReturnCoordinates(target_image_size) - ]) - elif crop_method == 'random-1d': - transform_functions.extend([ - transforms.Resize(target_image_size), - RandomCrop1dReturnCoordinates(target_image_size) - ]) - elif crop_method == 'random-2d': - transform_functions.extend([ - Random2dCropReturnCoordinates(target_image_size), - transforms.Resize(target_image_size) - ]) - elif crop_method is None: - return None - else: - raise ValueError(f'Received invalid crop method [{crop_method}].') - if random_flip: - transform_functions.append(RandomHorizontalFlipReturn()) - transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) - return transform_functions - - def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): - crop_bbox = None - flipped = None - for t in self.transform_functions: - if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): - crop_bbox, x = t(x) - elif isinstance(t, RandomHorizontalFlipReturn): - flipped, x = t(x) - else: - x = t(x) - return crop_bbox, flipped, x - - @property - def no_classes(self) -> int: - return self.no_object_classes if self.no_object_classes else len(self.categories) - - @property - def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: - # cannot set this up in init because no_classes is only known after loading data in init of superclass - if self._conditional_builders is None: - self._conditional_builders = { - 'objects_center_points': ObjectsCenterPointsConditionalBuilder( - self.no_classes, - self.max_objects_per_image, - self.no_tokens, - self.encode_crop, - self.use_group_parameter, - getattr(self, 
'use_additional_parameters', False) - ), - 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( - self.no_classes, - self.max_objects_per_image, - self.no_tokens, - self.encode_crop, - self.use_group_parameter, - getattr(self, 'use_additional_parameters', False) - ) - } - return self._conditional_builders - - def filter_categories(self) -> None: - if self.category_allow_list: - self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list} - if self.category_mapping: - self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping} - - def setup_category_id_and_number(self) -> None: - self.category_ids = list(self.categories.keys()) - self.category_ids.sort() - if '/m/01s55n' in self.category_ids: - self.category_ids.remove('/m/01s55n') - self.category_ids.append('/m/01s55n') - self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} - if self.category_allow_list is not None and self.category_mapping is None \ - and len(self.category_ids) != len(self.category_allow_list): - warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. ' - 'Make sure all names in category_allow_list exist.') - - def clean_up_annotations_and_image_descriptions(self) -> None: - image_id_set = set(self.image_ids) - self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set} - self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} - - @staticmethod - def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, - min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: - filtered = {} - for image_id, annotations in all_annotations.items(): - annotations_with_min_area = [a for a in annotations if a.area > min_object_area] - if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image: - filtered[image_id] = annotations_with_min_area - return filtered - - def __len__(self): - return len(self.image_ids) - - def __getitem__(self, n: int) -> Dict[str, Any]: - image_id = self.get_image_id(n) - sample = self.get_image_description(image_id) - sample['annotations'] = self.get_annotation(image_id) - - if 'image' in self.keys: - sample['image_path'] = str(self.get_image_path(image_id)) - sample['image'] = self.load_image_from_disk(sample['image_path']) - sample['image'] = convert_pil_to_tensor(sample['image']) - sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) - sample['image'] = sample['image'].permute(1, 2, 0) - - for conditional, builder in self.conditional_builders.items(): - if conditional in self.keys: - sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) - - if self.keys: - # only return specified keys - sample = {key: sample[key] for key in self.keys} - return sample - - def get_image_id(self, no: int) -> str: - return self.image_ids[no] - - def get_annotation(self, image_id: str) -> str: - return self.annotations[image_id] - - def get_textual_label_for_category_id(self, category_id: str) -> str: - return self.categories[category_id].name - - def get_textual_label_for_category_no(self, category_no: int) -> str: - return self.categories[self.get_category_id(category_no)].name - - def get_category_number(self, category_id: str) -> int: - return self.category_number[category_id] - - def get_category_id(self, category_no: int) -> 
str: - return self.category_ids[category_no] - - def get_image_description(self, image_id: str) -> Dict[str, Any]: - raise NotImplementedError() - - def get_path_structure(self): - raise NotImplementedError - - def get_image_path(self, image_id: str) -> Path: - raise NotImplementedError diff --git a/Control-Color/taming/data/annotated_objects_open_images.py b/Control-Color/taming/data/annotated_objects_open_images.py deleted file mode 100644 index aede6803d2cef7a74ca784e7907d35fba6c71239..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/annotated_objects_open_images.py +++ /dev/null @@ -1,137 +0,0 @@ -from collections import defaultdict -from csv import DictReader, reader as TupleReader -from pathlib import Path -from typing import Dict, List, Any -import warnings - -from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset -from taming.data.helper_types import Annotation, Category -from tqdm import tqdm - -OPEN_IMAGES_STRUCTURE = { - 'train': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'oidv6-train-annotations-bbox.csv', - 'file_list': 'train-images-boxable.csv', - 'files': 'train' - }, - 'validation': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'validation-annotations-bbox.csv', - 'file_list': 'validation-images.csv', - 'files': 'validation' - }, - 'test': { - 'top_level': '', - 'class_descriptions': 'class-descriptions-boxable.csv', - 'annotations': 'test-annotations-bbox.csv', - 'file_list': 'test-images.csv', - 'files': 'test' - } -} - - -def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], - category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: - annotations: Dict[str, List[Annotation]] = defaultdict(list) - with open(descriptor_path) as file: - reader = DictReader(file) - for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): - width = float(row['XMax']) - float(row['XMin']) - height = float(row['YMax']) - float(row['YMin']) - area = width * height - category_id = row['LabelName'] - if category_id in category_mapping: - category_id = category_mapping[category_id] - if area >= min_object_area and category_id in category_no_for_id: - annotations[row['ImageID']].append( - Annotation( - id=i, - image_id=row['ImageID'], - source=row['Source'], - category_id=category_id, - category_no=category_no_for_id[category_id], - confidence=float(row['Confidence']), - bbox=(float(row['XMin']), float(row['YMin']), width, height), - area=area, - is_occluded=bool(int(row['IsOccluded'])), - is_truncated=bool(int(row['IsTruncated'])), - is_group_of=bool(int(row['IsGroupOf'])), - is_depiction=bool(int(row['IsDepiction'])), - is_inside=bool(int(row['IsInside'])) - ) - ) - if 'train' in str(descriptor_path) and i < 14000000: - warnings.warn(f'Running with subset of Open Images. 
Train dataset has length [{len(annotations)}].') - return dict(annotations) - - -def load_image_ids(csv_path: Path) -> List[str]: - with open(csv_path) as file: - reader = DictReader(file) - return [row['image_name'] for row in reader] - - -def load_categories(csv_path: Path) -> Dict[str, Category]: - with open(csv_path) as file: - reader = TupleReader(file) - return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} - - -class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): - def __init__(self, use_additional_parameters: bool, **kwargs): - """ - @param data_path: is the path to the following folder structure: - open_images/ - │ oidv6-train-annotations-bbox.csv - ├── class-descriptions-boxable.csv - ├── oidv6-train-annotations-bbox.csv - ├── test - │ ├── 000026e7ee790996.jpg - │ ├── 000062a39995e348.jpg - │ └── ... - ├── test-annotations-bbox.csv - ├── test-images.csv - ├── train - │ ├── 000002b66c9c498e.jpg - │ ├── 000002b97e5471a0.jpg - │ └── ... - ├── train-images-boxable.csv - ├── validation - │ ├── 0001eeaf4aed83f9.jpg - │ ├── 0004886b7d043cfd.jpg - │ └── ... - ├── validation-annotations-bbox.csv - └── validation-images.csv - @param: split: one of 'train', 'validation' or 'test' - @param: desired image size (returns square images) - """ - - super().__init__(**kwargs) - self.use_additional_parameters = use_additional_parameters - - self.categories = load_categories(self.paths['class_descriptions']) - self.filter_categories() - self.setup_category_id_and_number() - - self.image_descriptions = {} - annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, - self.category_number) - self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, - self.max_objects_per_image) - self.image_ids = list(self.annotations.keys()) - self.clean_up_annotations_and_image_descriptions() - - def get_path_structure(self) -> Dict[str, str]: - if self.split not in OPEN_IMAGES_STRUCTURE: - raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') - return OPEN_IMAGES_STRUCTURE[self.split] - - def get_image_path(self, image_id: str) -> Path: - return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') - - def get_image_description(self, image_id: str) -> Dict[str, Any]: - image_path = self.get_image_path(image_id) - return {'file_path': str(image_path), 'file_name': image_path.name} diff --git a/Control-Color/taming/data/base.py b/Control-Color/taming/data/base.py deleted file mode 100644 index e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/base.py +++ /dev/null @@ -1,70 +0,0 @@ -import bisect -import numpy as np -import albumentations -from PIL import Image -from torch.utils.data import Dataset, ConcatDataset - - -class ConcatDatasetWithIndex(ConcatDataset): - """Modified from original pytorch code to return dataset idx""" - def __getitem__(self, idx): - if idx < 0: - if -idx > len(self): - raise ValueError("absolute value of index should not exceed dataset length") - idx = len(self) + idx - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - if dataset_idx == 0: - sample_idx = idx - else: - sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] - return self.datasets[dataset_idx][sample_idx], dataset_idx - - -class ImagePaths(Dataset): - def __init__(self, paths, size=None, random_crop=False, labels=None): - self.size = size - self.random_crop = random_crop - - self.labels = 
dict() if labels is None else labels - self.labels["file_path_"] = paths - self._length = len(paths) - - if self.size is not None and self.size > 0: - self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) - if not self.random_crop: - self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) - else: - self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) - self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) - else: - self.preprocessor = lambda **kwargs: kwargs - - def __len__(self): - return self._length - - def preprocess_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - image = np.array(image).astype(np.uint8) - image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) - return image - - def __getitem__(self, i): - example = dict() - example["image"] = self.preprocess_image(self.labels["file_path_"][i]) - for k in self.labels: - example[k] = self.labels[k][i] - return example - - -class NumpyPaths(ImagePaths): - def preprocess_image(self, image_path): - image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 - image = np.transpose(image, (1,2,0)) - image = Image.fromarray(image, mode="RGB") - image = np.array(image).astype(np.uint8) - image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) - return image diff --git a/Control-Color/taming/data/coco.py b/Control-Color/taming/data/coco.py deleted file mode 100644 index 2b2f7838448cb63dcf96daffe9470d58566d975a..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/coco.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import json -import albumentations -import numpy as np -from PIL import Image -from tqdm import tqdm -from torch.utils.data import Dataset - -from taming.data.sflckr import SegmentationBase # for examples included in repo - - -class Examples(SegmentationBase): - def __init__(self, size=256, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/coco_examples.txt", - data_root="data/coco_images", - segmentation_root="data/coco_segmentations", - size=size, random_crop=random_crop, - interpolation=interpolation, - n_labels=183, shift_segmentation=True) - - -class CocoBase(Dataset): - """needed for (image, caption, segmentation) pairs""" - def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, - crop_size=None, force_no_crop=False, given_files=None): - self.split = self.get_split() - self.size = size - if crop_size is None: - self.crop_size = size - else: - self.crop_size = crop_size - - self.onehot = onehot_segmentation # return segmentation as rgb or one hot - self.stuffthing = use_stuffthing # include thing in segmentation - if self.onehot and not self.stuffthing: - raise NotImplemented("One hot mode is only supported for the " - "stuffthings version because labels are stored " - "a bit different.") - - data_json = datajson - with open(data_json) as json_file: - self.json_data = json.load(json_file) - self.img_id_to_captions = dict() - self.img_id_to_filepath = dict() - self.img_id_to_segmentation_filepath = dict() - - assert data_json.split("/")[-1] in ["captions_train2017.json", - "captions_val2017.json"] - if self.stuffthing: - self.segmentation_prefix = ( - "data/cocostuffthings/val2017" if - data_json.endswith("captions_val2017.json") else - "data/cocostuffthings/train2017") - else: - self.segmentation_prefix = ( 
- "data/coco/annotations/stuff_val2017_pixelmaps" if - data_json.endswith("captions_val2017.json") else - "data/coco/annotations/stuff_train2017_pixelmaps") - - imagedirs = self.json_data["images"] - self.labels = {"image_ids": list()} - for imgdir in tqdm(imagedirs, desc="ImgToPath"): - self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) - self.img_id_to_captions[imgdir["id"]] = list() - pngfilename = imgdir["file_name"].replace("jpg", "png") - self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( - self.segmentation_prefix, pngfilename) - if given_files is not None: - if pngfilename in given_files: - self.labels["image_ids"].append(imgdir["id"]) - else: - self.labels["image_ids"].append(imgdir["id"]) - - capdirs = self.json_data["annotations"] - for capdir in tqdm(capdirs, desc="ImgToCaptions"): - # there are in average 5 captions per image - self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) - - self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) - if self.split=="validation": - self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) - else: - self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) - self.preprocessor = albumentations.Compose( - [self.rescaler, self.cropper], - additional_targets={"segmentation": "image"}) - if force_no_crop: - self.rescaler = albumentations.Resize(height=self.size, width=self.size) - self.preprocessor = albumentations.Compose( - [self.rescaler], - additional_targets={"segmentation": "image"}) - - def __len__(self): - return len(self.labels["image_ids"]) - - def preprocess_image(self, image_path, segmentation_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - image = np.array(image).astype(np.uint8) - - segmentation = Image.open(segmentation_path) - if not self.onehot and not segmentation.mode == "RGB": - segmentation = segmentation.convert("RGB") - segmentation = np.array(segmentation).astype(np.uint8) - if self.onehot: - assert self.stuffthing - # stored in caffe format: unlabeled==255. stuff and thing from - # 0-181. 
to be compatible with the labels in - # https://github.com/nightrome/cocostuff/blob/master/labels.txt - # we shift stuffthing one to the right and put unlabeled in zero - # as long as segmentation is uint8 shifting to right handles the - # latter too - assert segmentation.dtype == np.uint8 - segmentation = segmentation + 1 - - processed = self.preprocessor(image=image, segmentation=segmentation) - image, segmentation = processed["image"], processed["segmentation"] - image = (image / 127.5 - 1.0).astype(np.float32) - - if self.onehot: - assert segmentation.dtype == np.uint8 - # make it one hot - n_labels = 183 - flatseg = np.ravel(segmentation) - onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) - onehot[np.arange(flatseg.size), flatseg] = True - onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) - segmentation = onehot - else: - segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) - return image, segmentation - - def __getitem__(self, i): - img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] - seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] - image, segmentation = self.preprocess_image(img_path, seg_path) - captions = self.img_id_to_captions[self.labels["image_ids"][i]] - # randomly draw one of all available captions per image - caption = captions[np.random.randint(0, len(captions))] - example = {"image": image, - "caption": [str(caption[0])], - "segmentation": segmentation, - "img_path": img_path, - "seg_path": seg_path, - "filename_": img_path.split(os.sep)[-1] - } - return example - - -class CocoImagesAndCaptionsTrain(CocoBase): - """returns a pair of (image, caption)""" - def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): - super().__init__(size=size, - dataroot="data/coco/train2017", - datajson="data/coco/annotations/captions_train2017.json", - onehot_segmentation=onehot_segmentation, - use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) - - def get_split(self): - return "train" - - -class CocoImagesAndCaptionsValidation(CocoBase): - """returns a pair of (image, caption)""" - def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, - given_files=None): - super().__init__(size=size, - dataroot="data/coco/val2017", - datajson="data/coco/annotations/captions_val2017.json", - onehot_segmentation=onehot_segmentation, - use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, - given_files=given_files) - - def get_split(self): - return "validation" diff --git a/Control-Color/taming/data/conditional_builder/objects_bbox.py b/Control-Color/taming/data/conditional_builder/objects_bbox.py deleted file mode 100644 index 15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/conditional_builder/objects_bbox.py +++ /dev/null @@ -1,60 +0,0 @@ -from itertools import cycle -from typing import List, Tuple, Callable, Optional - -from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont -from more_itertools.recipes import grouper -from taming.data.image_transforms import convert_pil_to_tensor -from torch import LongTensor, Tensor - -from taming.data.helper_types import BoundingBox, Annotation -from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder -from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, 
additional_parameters_string, \ - pad_list, get_plot_font_size, absolute_bbox - - -class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): - @property - def object_descriptor_length(self) -> int: - return 3 - - def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: - object_triples = [ - (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) - for ann in annotations - ] - empty_triple = (self.none, self.none, self.none) - object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) - return object_triples - - def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: - conditional_list = conditional.tolist() - crop_coordinates = None - if self.encode_crop: - crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) - conditional_list = conditional_list[:-2] - object_triples = grouper(conditional_list, 3) - assert conditional.shape[0] == self.embedding_dim - return [ - (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) - for object_triple in object_triples if object_triple[0] != self.none - ], crop_coordinates - - def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], - line_width: int = 3, font_size: Optional[int] = None) -> Tensor: - plot = pil_image.new('RGB', figure_size, WHITE) - draw = pil_img_draw.Draw(plot) - font = ImageFont.truetype( - "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", - size=get_plot_font_size(font_size, figure_size) - ) - width, height = plot.size - description, crop_coordinates = self.inverse_build(conditional) - for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): - annotation = self.representation_to_annotation(representation) - class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) - bbox = absolute_bbox(bbox, width, height) - draw.rectangle(bbox, outline=color, width=line_width) - draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) - if crop_coordinates is not None: - draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) - return convert_pil_to_tensor(plot) / 127.5 - 1. 
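A minimal sketch of the token layout produced by ObjectsBoundingBoxConditionalBuilder above, assuming no_tokens = 16 and no_max_objects = 2 (both values chosen only for illustration). Each annotation becomes a triple (object token, top-left corner token, bottom-right corner token); the corner tokens come from the parent class's tokenize_coordinates, which quantizes relative (x, y) in [0, 1] onto a sqrt(no_tokens) grid (defined in objects_center_points.py below), and unused object slots are padded with the `none` token.

import math

no_tokens = 16                            # assumed vocabulary size for this sketch
no_sections = int(math.sqrt(no_tokens))   # 4 x 4 coordinate grid
none_token = no_tokens - 1                # padding token, mirrors `self.none`

def tokenize_coordinates(x, y):
    # quantize relative (x, y) in [0, 1] to a single row-major grid token
    x_d = int(round(x * (no_sections - 1)))
    y_d = int(round(y * (no_sections - 1)))
    return y_d * no_sections + x_d

def token_pair_from_bbox(bbox):
    # bbox = (x0, y0, w, h): tokenize the top-left and bottom-right corners
    return (tokenize_coordinates(bbox[0], bbox[1]),
            tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]))

# one annotation with a hypothetical class token 5 covering the upper-left quadrant,
# padded with `none` triples up to no_max_objects = 2
triple = (5, *token_pair_from_bbox((0.0, 0.0, 0.5, 0.5)))  # -> (5, 0, 10)
sequence = list(triple) + [none_token] * 3                 # -> [5, 0, 10, 15, 15, 15]

The resulting flat sequence has length no_max_objects * 3 (plus two extra crop tokens when encode_crop is enabled), which matches the embedding_dim check in the builder's build method.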
diff --git a/Control-Color/taming/data/conditional_builder/objects_center_points.py b/Control-Color/taming/data/conditional_builder/objects_center_points.py deleted file mode 100644 index 9a480329cc47fb38a7b8729d424e092b77d40749..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/conditional_builder/objects_center_points.py +++ /dev/null @@ -1,168 +0,0 @@ -import math -import random -import warnings -from itertools import cycle -from typing import List, Optional, Tuple, Callable - -from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont -from more_itertools.recipes import grouper -from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ - additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ - absolute_bbox, rescale_annotations -from taming.data.helper_types import BoundingBox, Annotation -from taming.data.image_transforms import convert_pil_to_tensor -from torch import LongTensor, Tensor - - -class ObjectsCenterPointsConditionalBuilder: - def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, - use_group_parameter: bool, use_additional_parameters: bool): - self.no_object_classes = no_object_classes - self.no_max_objects = no_max_objects - self.no_tokens = no_tokens - self.encode_crop = encode_crop - self.no_sections = int(math.sqrt(self.no_tokens)) - self.use_group_parameter = use_group_parameter - self.use_additional_parameters = use_additional_parameters - - @property - def none(self) -> int: - return self.no_tokens - 1 - - @property - def object_descriptor_length(self) -> int: - return 2 - - @property - def embedding_dim(self) -> int: - extra_length = 2 if self.encode_crop else 0 - return self.no_max_objects * self.object_descriptor_length + extra_length - - def tokenize_coordinates(self, x: float, y: float) -> int: - """ - Express 2d coordinates with one number. - Example: assume self.no_tokens = 16, then no_sections = 4: - 0 0 0 0 - 0 0 # 0 - 0 0 0 0 - 0 0 0 x - Then the # position corresponds to token 6, the x position to token 15. 
- @param x: float in [0, 1] - @param y: float in [0, 1] - @return: discrete tokenized coordinate - """ - x_discrete = int(round(x * (self.no_sections - 1))) - y_discrete = int(round(y * (self.no_sections - 1))) - return y_discrete * self.no_sections + x_discrete - - def coordinates_from_token(self, token: int) -> (float, float): - x = token % self.no_sections - y = token // self.no_sections - return x / (self.no_sections - 1), y / (self.no_sections - 1) - - def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: - x0, y0 = self.coordinates_from_token(token1) - x1, y1 = self.coordinates_from_token(token2) - return x0, y0, x1 - x0, y1 - y0 - - def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: - return self.tokenize_coordinates(bbox[0], bbox[1]), \ - self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) - - def inverse_build(self, conditional: LongTensor) \ - -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: - conditional_list = conditional.tolist() - crop_coordinates = None - if self.encode_crop: - crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) - conditional_list = conditional_list[:-2] - table_of_content = grouper(conditional_list, self.object_descriptor_length) - assert conditional.shape[0] == self.embedding_dim - return [ - (object_tuple[0], self.coordinates_from_token(object_tuple[1])) - for object_tuple in table_of_content if object_tuple[0] != self.none - ], crop_coordinates - - def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], - line_width: int = 3, font_size: Optional[int] = None) -> Tensor: - plot = pil_image.new('RGB', figure_size, WHITE) - draw = pil_img_draw.Draw(plot) - circle_size = get_circle_size(figure_size) - font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', - size=get_plot_font_size(font_size, figure_size)) - width, height = plot.size - description, crop_coordinates = self.inverse_build(conditional) - for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): - x_abs, y_abs = x * width, y * height - ann = self.representation_to_annotation(representation) - label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) - ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] - draw.ellipse(ellipse_bbox, fill=color, width=0) - draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) - if crop_coordinates is not None: - draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) - return convert_pil_to_tensor(plot) / 127.5 - 1. 
- - def object_representation(self, annotation: Annotation) -> int: - modifier = 0 - if self.use_group_parameter: - modifier |= 1 * (annotation.is_group_of is True) - if self.use_additional_parameters: - modifier |= 2 * (annotation.is_occluded is True) - modifier |= 4 * (annotation.is_depiction is True) - modifier |= 8 * (annotation.is_inside is True) - return annotation.category_no + self.no_object_classes * modifier - - def representation_to_annotation(self, representation: int) -> Annotation: - category_no = representation % self.no_object_classes - modifier = representation // self.no_object_classes - # noinspection PyTypeChecker - return Annotation( - area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, - category_no=category_no, - is_group_of=bool((modifier & 1) * self.use_group_parameter), - is_occluded=bool((modifier & 2) * self.use_additional_parameters), - is_depiction=bool((modifier & 4) * self.use_additional_parameters), - is_inside=bool((modifier & 8) * self.use_additional_parameters) - ) - - def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: - return list(self.token_pair_from_bbox(crop_coordinates)) - - def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: - object_tuples = [ - (self.object_representation(a), - self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) - for a in annotations - ] - empty_tuple = (self.none, self.none) - object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) - return object_tuples - - def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ - -> LongTensor: - if len(annotations) == 0: - warnings.warn('Did not receive any annotations.') - if len(annotations) > self.no_max_objects: - warnings.warn('Received more annotations than allowed.') - annotations = annotations[:self.no_max_objects] - - if not crop_coordinates: - crop_coordinates = FULL_CROP - - random.shuffle(annotations) - annotations = filter_annotations(annotations, crop_coordinates) - if self.encode_crop: - annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) - if horizontal_flip: - crop_coordinates = horizontally_flip_bbox(crop_coordinates) - extra = self._crop_encoder(crop_coordinates) - else: - annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) - extra = [] - - object_tuples = self._make_object_descriptors(annotations) - flattened = [token for tuple_ in object_tuples for token in tuple_] + extra - assert len(flattened) == self.embedding_dim - assert all(0 <= value < self.no_tokens for value in flattened) - return LongTensor(flattened) diff --git a/Control-Color/taming/data/conditional_builder/utils.py b/Control-Color/taming/data/conditional_builder/utils.py deleted file mode 100644 index d0ee175f2e05a80dbc71c22acbecb22dddadbb42..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/conditional_builder/utils.py +++ /dev/null @@ -1,105 +0,0 @@ -import importlib -from typing import List, Any, Tuple, Optional - -from taming.data.helper_types import BoundingBox, Annotation - -# source: seaborn, color palette tab10 -COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), - (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] -BLACK = (0, 0, 0) -GRAY_75 = (63, 63, 63) -GRAY_50 = (127, 127, 127) -GRAY_25 = (191, 191, 191) -WHITE = (255, 255, 255) -FULL_CROP = 
(0., 0., 1., 1.) - - -def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: - """ - Give intersection area of two rectangles. - @param rectangle1: (x0, y0, w, h) of first rectangle - @param rectangle2: (x0, y0, w, h) of second rectangle - """ - rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] - rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] - x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) - y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) - return x_overlap * y_overlap - - -def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: - return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] - - -def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: - bbox = relative_bbox - bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height - return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) - - -def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: - return list_ + [pad_element for _ in range(pad_to_length - len(list_))] - - -def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ - List[Annotation]: - def clamp(x: float): - return max(min(x, 1.), 0.) - - def rescale_bbox(bbox: BoundingBox) -> BoundingBox: - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - if flip: - x0 = 1 - (x0 + w) - return x0, y0, w, h - - return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] - - -def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: - return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] - - -def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: - sl = slice(1) if short else slice(None) - string = '' - if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): - return string - if annotation.is_group_of: - string += 'group'[sl] + ',' - if annotation.is_occluded: - string += 'occluded'[sl] + ',' - if annotation.is_depiction: - string += 'depiction'[sl] + ',' - if annotation.is_inside: - string += 'inside'[sl] - return '(' + string.strip(",") + ')' - - -def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: - if font_size is None: - font_size = 10 - if max(figure_size) >= 256: - font_size = 12 - if max(figure_size) >= 512: - font_size = 15 - return font_size - - -def get_circle_size(figure_size: Tuple[int, int]) -> int: - circle_size = 2 - if max(figure_size) >= 256: - circle_size = 3 - if max(figure_size) >= 512: - circle_size = 4 - return circle_size - - -def load_object_from_string(object_string: str) -> Any: - """ - Source: https://stackoverflow.com/a/10773699 - """ - module_name, class_name = object_string.rsplit(".", 1) - return getattr(importlib.import_module(module_name), class_name) diff --git a/Control-Color/taming/data/custom.py b/Control-Color/taming/data/custom.py deleted file mode 100644 index 33f302a4b55ba1e8ec282ec3292b6263c06dfb91..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/custom.py +++ /dev/null @@ 
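For reference, a self-contained sketch of the (x0, y0, w, h) intersection test and the crop-relative bbox rescaling defined above, with values chosen purely for illustration:

# Sketch of intersection_area and the rescale_bbox closure from rescale_annotations above.
def intersection_area(r1, r2):
    ax0, ay0, ax1, ay1 = r1[0], r1[1], r1[0] + r1[2], r1[1] + r1[3]
    bx0, by0, bx1, by1 = r2[0], r2[1], r2[0] + r2[2], r2[1] + r2[3]
    x_overlap = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    y_overlap = max(0.0, min(ay1, by1) - max(ay0, by0))
    return x_overlap * y_overlap

def rescale_bbox(bbox, crop):
    # Express bbox relative to the crop window, clamping to [0, 1].
    clamp = lambda v: max(min(v, 1.0), 0.0)
    x0 = clamp((bbox[0] - crop[0]) / crop[2])
    y0 = clamp((bbox[1] - crop[1]) / crop[3])
    w = min(bbox[2] / crop[2], 1.0 - x0)
    h = min(bbox[3] / crop[3], 1.0 - y0)
    return x0, y0, w, h

print(intersection_area((0.0, 0.0, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5)))  # 0.0625
print(rescale_bbox((0.3, 0.3, 0.2, 0.2), (0.25, 0.25, 0.5, 0.5)))       # approx (0.1, 0.1, 0.4, 0.4)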
-1,38 +0,0 @@ -import os -import numpy as np -import albumentations -from torch.utils.data import Dataset - -from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex - - -class CustomBase(Dataset): - def __init__(self, *args, **kwargs): - super().__init__() - self.data = None - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - example = self.data[i] - return example - - - -class CustomTrain(CustomBase): - def __init__(self, size, training_images_list_file): - super().__init__() - with open(training_images_list_file, "r") as f: - paths = f.read().splitlines() - self.data = ImagePaths(paths=paths, size=size, random_crop=False) - - -class CustomTest(CustomBase): - def __init__(self, size, test_images_list_file): - super().__init__() - with open(test_images_list_file, "r") as f: - paths = f.read().splitlines() - self.data = ImagePaths(paths=paths, size=size, random_crop=False) - - diff --git a/Control-Color/taming/data/faceshq.py b/Control-Color/taming/data/faceshq.py deleted file mode 100644 index 6912d04b66a6d464c1078e4b51d5da290f5e767e..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/faceshq.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import numpy as np -import albumentations -from torch.utils.data import Dataset - -from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex - - -class FacesBase(Dataset): - def __init__(self, *args, **kwargs): - super().__init__() - self.data = None - self.keys = None - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - example = self.data[i] - ex = {} - if self.keys is not None: - for k in self.keys: - ex[k] = example[k] - else: - ex = example - return ex - - -class CelebAHQTrain(FacesBase): - def __init__(self, size, keys=None): - super().__init__() - root = "data/celebahq" - with open("data/celebahqtrain.txt", "r") as f: - relpaths = f.read().splitlines() - paths = [os.path.join(root, relpath) for relpath in relpaths] - self.data = NumpyPaths(paths=paths, size=size, random_crop=False) - self.keys = keys - - -class CelebAHQValidation(FacesBase): - def __init__(self, size, keys=None): - super().__init__() - root = "data/celebahq" - with open("data/celebahqvalidation.txt", "r") as f: - relpaths = f.read().splitlines() - paths = [os.path.join(root, relpath) for relpath in relpaths] - self.data = NumpyPaths(paths=paths, size=size, random_crop=False) - self.keys = keys - - -class FFHQTrain(FacesBase): - def __init__(self, size, keys=None): - super().__init__() - root = "data/ffhq" - with open("data/ffhqtrain.txt", "r") as f: - relpaths = f.read().splitlines() - paths = [os.path.join(root, relpath) for relpath in relpaths] - self.data = ImagePaths(paths=paths, size=size, random_crop=False) - self.keys = keys - - -class FFHQValidation(FacesBase): - def __init__(self, size, keys=None): - super().__init__() - root = "data/ffhq" - with open("data/ffhqvalidation.txt", "r") as f: - relpaths = f.read().splitlines() - paths = [os.path.join(root, relpath) for relpath in relpaths] - self.data = ImagePaths(paths=paths, size=size, random_crop=False) - self.keys = keys - - -class FacesHQTrain(Dataset): - # CelebAHQ [0] + FFHQ [1] - def __init__(self, size, keys=None, crop_size=None, coord=False): - d1 = CelebAHQTrain(size=size, keys=keys) - d2 = FFHQTrain(size=size, keys=keys) - self.data = ConcatDatasetWithIndex([d1, d2]) - self.coord = coord - if crop_size is not None: - self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) - if 
self.coord: - self.cropper = albumentations.Compose([self.cropper], - additional_targets={"coord": "image"}) - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - ex, y = self.data[i] - if hasattr(self, "cropper"): - if not self.coord: - out = self.cropper(image=ex["image"]) - ex["image"] = out["image"] - else: - h,w,_ = ex["image"].shape - coord = np.arange(h*w).reshape(h,w,1)/(h*w) - out = self.cropper(image=ex["image"], coord=coord) - ex["image"] = out["image"] - ex["coord"] = out["coord"] - ex["class"] = y - return ex - - -class FacesHQValidation(Dataset): - # CelebAHQ [0] + FFHQ [1] - def __init__(self, size, keys=None, crop_size=None, coord=False): - d1 = CelebAHQValidation(size=size, keys=keys) - d2 = FFHQValidation(size=size, keys=keys) - self.data = ConcatDatasetWithIndex([d1, d2]) - self.coord = coord - if crop_size is not None: - self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) - if self.coord: - self.cropper = albumentations.Compose([self.cropper], - additional_targets={"coord": "image"}) - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - ex, y = self.data[i] - if hasattr(self, "cropper"): - if not self.coord: - out = self.cropper(image=ex["image"]) - ex["image"] = out["image"] - else: - h,w,_ = ex["image"].shape - coord = np.arange(h*w).reshape(h,w,1)/(h*w) - out = self.cropper(image=ex["image"], coord=coord) - ex["image"] = out["image"] - ex["coord"] = out["coord"] - ex["class"] = y - return ex diff --git a/Control-Color/taming/data/helper_types.py b/Control-Color/taming/data/helper_types.py deleted file mode 100644 index fb51e301da08602cfead5961c4f7e1d89f6aba79..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/helper_types.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Dict, Tuple, Optional, NamedTuple, Union -from PIL.Image import Image as pil_image -from torch import Tensor - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - -Image = Union[Tensor, pil_image] -BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h -CropMethodType = Literal['none', 'random', 'center', 'random-2d'] -SplitType = Literal['train', 'validation', 'test'] - - -class ImageDescription(NamedTuple): - id: int - file_name: str - original_size: Tuple[int, int] # w, h - url: Optional[str] = None - license: Optional[int] = None - coco_url: Optional[str] = None - date_captured: Optional[str] = None - flickr_url: Optional[str] = None - flickr_id: Optional[str] = None - coco_id: Optional[str] = None - - -class Category(NamedTuple): - id: str - super_category: Optional[str] - name: str - - -class Annotation(NamedTuple): - area: float - image_id: str - bbox: BoundingBox - category_no: int - category_id: str - id: Optional[int] = None - source: Optional[str] = None - confidence: Optional[float] = None - is_group_of: Optional[bool] = None - is_truncated: Optional[bool] = None - is_occluded: Optional[bool] = None - is_depiction: Optional[bool] = None - is_inside: Optional[bool] = None - segmentation: Optional[Dict] = None diff --git a/Control-Color/taming/data/image_transforms.py b/Control-Color/taming/data/image_transforms.py deleted file mode 100644 index 657ac332174e0ac72f68315271ffbd757b771a0f..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/image_transforms.py +++ /dev/null @@ -1,132 +0,0 @@ -import random -import warnings -from typing import Union - -import torch -from torch import Tensor -from torchvision.transforms 
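The coordinate-map cropping in FacesHQTrain above relies on albumentations' additional_targets so that the random crop window applied to the image is also applied to the coord array. A small sketch with random stand-in data:

import numpy as np
import albumentations

# Sketch: crop an image and a per-pixel coordinate map with the same random window,
# mirroring the additional_targets usage above. Shapes and data are stand-ins.
h, w = 300, 400
image = np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
coord = np.arange(h * w).reshape(h, w, 1) / (h * w)      # normalised pixel index

cropper = albumentations.Compose(
    [albumentations.RandomCrop(height=256, width=256)],
    additional_targets={"coord": "image"})               # treat "coord" like a second image

out = cropper(image=image, coord=coord)
assert out["image"].shape[:2] == out["coord"].shape[:2] == (256, 256)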
import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor -from torchvision.transforms.functional import _get_image_size as get_image_size - -from taming.data.helper_types import BoundingBox, Image - -pil_to_tensor = PILToTensor() - - -def convert_pil_to_tensor(image: Image) -> Tensor: - with warnings.catch_warnings(): - # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 - warnings.simplefilter("ignore") - return pil_to_tensor(image) - - -class RandomCrop1dReturnCoordinates(RandomCrop): - def forward(self, img: Image) -> (BoundingBox, Image): - """ - Additionally to cropping, returns the relative coordinates of the crop bounding box. - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - Bounding box: x0, y0, w, h - PIL Image or Tensor: Cropped image. - - Based on: - torchvision.transforms.RandomCrop, torchvision 1.7.0 - """ - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) - - width, height = get_image_size(img) - # pad the width if needed - if self.pad_if_needed and width < self.size[1]: - padding = [self.size[1] - width, 0] - img = F.pad(img, padding, self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and height < self.size[0]: - padding = [0, self.size[0] - height] - img = F.pad(img, padding, self.fill, self.padding_mode) - - i, j, h, w = self.get_params(img, self.size) - bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h - return bbox, F.crop(img, i, j, h, w) - - -class Random2dCropReturnCoordinates(torch.nn.Module): - """ - Additionally to cropping, returns the relative coordinates of the crop bounding box. - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - Bounding box: x0, y0, w, h - PIL Image or Tensor: Cropped image. - - Based on: - torchvision.transforms.RandomCrop, torchvision 1.7.0 - """ - - def __init__(self, min_size: int): - super().__init__() - self.min_size = min_size - - def forward(self, img: Image) -> (BoundingBox, Image): - width, height = get_image_size(img) - max_size = min(width, height) - if max_size <= self.min_size: - size = max_size - else: - size = random.randint(self.min_size, max_size) - top = random.randint(0, height - size) - left = random.randint(0, width - size) - bbox = left / width, top / height, size / width, size / height - return bbox, F.crop(img, top, left, size, size) - - -class CenterCropReturnCoordinates(CenterCrop): - @staticmethod - def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: - if width > height: - w = height / width - h = 1.0 - x0 = 0.5 - w / 2 - y0 = 0. - else: - w = 1.0 - h = width / height - x0 = 0. - y0 = 0.5 - h / 2 - return x0, y0, w, h - - def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): - """ - Additionally to cropping, returns the relative coordinates of the crop bounding box. - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - Bounding box: x0, y0, w, h - PIL Image or Tensor: Cropped image. - Based on: - torchvision.transforms.RandomHorizontalFlip (version 1.7.0) - """ - width, height = get_image_size(img) - return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) - - -class RandomHorizontalFlipReturn(RandomHorizontalFlip): - def forward(self, img: Image) -> (bool, Image): - """ - Additionally to flipping, returns a boolean whether it was flipped or not. - Args: - img (PIL Image or Tensor): Image to be flipped. 
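CenterCropReturnCoordinates.get_bbox_of_center_crop above reduces to a little aspect-ratio arithmetic; a worked sketch:

# Sketch of the relative bbox of a square center crop, as computed above.
def center_crop_bbox(width, height):
    if width > height:                  # landscape: crop is full-height, centred horizontally
        w, h = height / width, 1.0
        return 0.5 - w / 2, 0.0, w, h
    w, h = 1.0, width / height          # portrait or square: crop is full-width
    return 0.0, 0.5 - h / 2, w, h

print(center_crop_bbox(640, 480))       # (0.125, 0.0, 0.75, 1.0)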
- - Returns: - flipped: whether the image was flipped or not - PIL Image or Tensor: Randomly flipped image. - - Based on: - torchvision.transforms.RandomHorizontalFlip (version 1.7.0) - """ - if torch.rand(1) < self.p: - return True, F.hflip(img) - return False, img diff --git a/Control-Color/taming/data/imagenet.py b/Control-Color/taming/data/imagenet.py deleted file mode 100644 index 9a02ec44ba4af9e993f58c91fa43482a4ecbe54c..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/imagenet.py +++ /dev/null @@ -1,558 +0,0 @@ -import os, tarfile, glob, shutil -import yaml -import numpy as np -from tqdm import tqdm -from PIL import Image -import albumentations -from omegaconf import OmegaConf -from torch.utils.data import Dataset - -from taming.data.base import ImagePaths -from taming.util import download, retrieve -import taming.data.utils as bdu - - -def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): - synsets = [] - with open(path_to_yaml) as f: - di2s = yaml.load(f) - for idx in indices: - synsets.append(str(di2s[idx])) - print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets))) - return synsets - - -def str_to_indices(string): - """Expects a string in the format '32-123, 256, 280-321'""" - assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string) - subs = string.split(",") - indices = [] - for sub in subs: - subsubs = sub.split("-") - assert len(subsubs) > 0 - if len(subsubs) == 1: - indices.append(int(subsubs[0])) - else: - rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))] - indices.extend(rang) - return sorted(indices) - - -class ImageNetBase(Dataset): - def __init__(self, config=None): - self.config = config or OmegaConf.create() - if not type(self.config)==dict: - self.config = OmegaConf.to_container(self.config) - self._prepare() - self._prepare_synset_to_human() - self._prepare_idx_to_synset() - self._load() - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - def _prepare(self): - raise NotImplementedError() - - def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) - relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] - if "sub_indices" in self.config: - indices = str_to_indices(self.config["sub_indices"]) - synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings - files = [] - for rpath in relpaths: - syn = rpath.split("/")[0] - if syn in synsets: - files.append(rpath) - return files - else: - return relpaths - - def _prepare_synset_to_human(self): - SIZE = 2655750 - URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" - self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): - download(URL, self.human_dict) - - def _prepare_idx_to_synset(self): - URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" - self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): - download(URL, self.idx2syn) - - def _load(self): - with open(self.txt_filelist, "r") as f: - self.relpaths = f.read().splitlines() - l1 = len(self.relpaths) - self.relpaths = self._filter_relpaths(self.relpaths) - print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) - - self.synsets = [p.split("/")[0] for p in 
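str_to_indices above accepts strings such as '32-123, 256, 280-321', with ranges end-exclusive like Python's range(). A compact sketch of the same parsing:

# Sketch of the index-string format parsed by str_to_indices above.
def parse_indices(spec):
    indices = []
    for part in spec.split(","):
        bounds = part.split("-")
        if len(bounds) == 1:
            indices.append(int(bounds[0]))
        else:
            indices.extend(range(int(bounds[0]), int(bounds[1])))  # end-exclusive
    return sorted(indices)

print(parse_indices("30-32, 256"))      # [30, 31, 256]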
self.relpaths] - self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] - - unique_synsets = np.unique(self.synsets) - class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) - self.class_labels = [class_dict[s] for s in self.synsets] - - with open(self.human_dict, "r") as f: - human_dict = f.read().splitlines() - human_dict = dict(line.split(maxsplit=1) for line in human_dict) - - self.human_labels = [human_dict[s] for s in self.synsets] - - labels = { - "relpath": np.array(self.relpaths), - "synsets": np.array(self.synsets), - "class_label": np.array(self.class_labels), - "human_label": np.array(self.human_labels), - } - self.data = ImagePaths(self.abspaths, - labels=labels, - size=retrieve(self.config, "size", default=0), - random_crop=self.random_crop) - - -class ImageNetTrain(ImageNetBase): - NAME = "ILSVRC2012_train" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" - FILES = [ - "ILSVRC2012_img_train.tar", - ] - SIZES = [ - 147897477120, - ] - - def _prepare(self): - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 1281167 - if not bdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: - import academictorrents as at - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - print("Extracting sub-tars.") - subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) - for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] - os.makedirs(subdir, exist_ok=True) - with tarfile.open(subpath, "r:") as tar: - tar.extractall(path=subdir) - - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - bdu.mark_prepared(self.root) - - -class ImageNetValidation(ImageNetBase): - NAME = "ILSVRC2012_validation" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" - VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" - FILES = [ - "ILSVRC2012_img_val.tar", - "validation_synset.txt", - ] - SIZES = [ - 6744924160, - 1950000, - ] - - def _prepare(self): - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 50000 - if not bdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not 
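In _load above, integer class labels come from enumerating the unique synset prefixes of the relative paths. A sketch with made-up paths:

import numpy as np

# Sketch of how synset directory prefixes become integer class labels above
# (paths are invented for illustration).
relpaths = ["n01440764/a.JPEG", "n01443537/b.JPEG", "n01440764/c.JPEG"]
synsets = [p.split("/")[0] for p in relpaths]
class_dict = {synset: i for i, synset in enumerate(np.unique(synsets))}
class_labels = [class_dict[s] for s in synsets]
print(class_labels)                     # [0, 1, 0]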
os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: - import academictorrents as at - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: - download(self.VS_URL, vspath) - - with open(vspath, "r") as f: - synset_dict = f.read().splitlines() - synset_dict = dict(line.split() for line in synset_dict) - - print("Reorganizing into synset folders") - synsets = np.unique(list(synset_dict.values())) - for s in synsets: - os.makedirs(os.path.join(datadir, s), exist_ok=True) - for k, v in synset_dict.items(): - src = os.path.join(datadir, k) - dst = os.path.join(datadir, v) - shutil.move(src, dst) - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - bdu.mark_prepared(self.root) - - -def get_preprocessor(size=None, random_crop=False, additional_targets=None, - crop_size=None): - if size is not None and size > 0: - transforms = list() - rescaler = albumentations.SmallestMaxSize(max_size = size) - transforms.append(rescaler) - if not random_crop: - cropper = albumentations.CenterCrop(height=size,width=size) - transforms.append(cropper) - else: - cropper = albumentations.RandomCrop(height=size,width=size) - transforms.append(cropper) - flipper = albumentations.HorizontalFlip() - transforms.append(flipper) - preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) - elif crop_size is not None and crop_size > 0: - if not random_crop: - cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) - else: - cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) - transforms = [cropper] - preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) - else: - preprocessor = lambda **kwargs: kwargs - return preprocessor - - -def rgba_to_depth(x): - assert x.dtype == np.uint8 - assert len(x.shape) == 3 and x.shape[2] == 4 - y = x.copy() - y.dtype = np.float32 - y = y.reshape(x.shape[:2]) - return np.ascontiguousarray(y) - - -class BaseWithDepth(Dataset): - DEFAULT_DEPTH_ROOT="data/imagenet_depth" - - def __init__(self, config=None, size=None, random_crop=False, - crop_size=None, root=None): - self.config = config - self.base_dset = self.get_base_dset() - self.preprocessor = get_preprocessor( - size=size, - crop_size=crop_size, - random_crop=random_crop, - additional_targets={"depth": "image"}) - self.crop_size = crop_size - if self.crop_size is not None: - self.rescaler = albumentations.Compose( - [albumentations.SmallestMaxSize(max_size = self.crop_size)], - additional_targets={"depth": "image"}) - if root is not None: - self.DEFAULT_DEPTH_ROOT = root - - def __len__(self): - return len(self.base_dset) - - def preprocess_depth(self, path): - rgba = np.array(Image.open(path)) - depth = rgba_to_depth(rgba) - depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) - depth = 2.0*depth-1.0 - return depth - - def __getitem__(self, i): - e = self.base_dset[i] - e["depth"] = 
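rgba_to_depth above recovers a float32 depth map that was stored as an H x W x 4 uint8 (RGBA) image by reinterpreting the raw bytes. A sketch of the round trip using numpy views, with a random depth map standing in for real data:

import numpy as np

# Sketch: a float32 depth map viewed as RGBA bytes and recovered again,
# which is the reinterpretation rgba_to_depth above performs.
depth = np.random.rand(4, 5).astype(np.float32)
rgba = depth.view(np.uint8).reshape(4, 5, 4)                       # 4 bytes per float32 pixel
recovered = np.ascontiguousarray(rgba).view(np.float32).reshape(4, 5)
assert np.array_equal(recovered, depth)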
self.preprocess_depth(self.get_depth_path(e)) - # up if necessary - h,w,c = e["image"].shape - if self.crop_size and min(h,w) < self.crop_size: - # have to upscale to be able to crop - this just uses bilinear - out = self.rescaler(image=e["image"], depth=e["depth"]) - e["image"] = out["image"] - e["depth"] = out["depth"] - transformed = self.preprocessor(image=e["image"], depth=e["depth"]) - e["image"] = transformed["image"] - e["depth"] = transformed["depth"] - return e - - -class ImageNetTrainWithDepth(BaseWithDepth): - # default to random_crop=True - def __init__(self, random_crop=True, sub_indices=None, **kwargs): - self.sub_indices = sub_indices - super().__init__(random_crop=random_crop, **kwargs) - - def get_base_dset(self): - if self.sub_indices is None: - return ImageNetTrain() - else: - return ImageNetTrain({"sub_indices": self.sub_indices}) - - def get_depth_path(self, e): - fid = os.path.splitext(e["relpath"])[0]+".png" - fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) - return fid - - -class ImageNetValidationWithDepth(BaseWithDepth): - def __init__(self, sub_indices=None, **kwargs): - self.sub_indices = sub_indices - super().__init__(**kwargs) - - def get_base_dset(self): - if self.sub_indices is None: - return ImageNetValidation() - else: - return ImageNetValidation({"sub_indices": self.sub_indices}) - - def get_depth_path(self, e): - fid = os.path.splitext(e["relpath"])[0]+".png" - fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) - return fid - - -class RINTrainWithDepth(ImageNetTrainWithDepth): - def __init__(self, config=None, size=None, random_crop=True, crop_size=None): - sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" - super().__init__(config=config, size=size, random_crop=random_crop, - sub_indices=sub_indices, crop_size=crop_size) - - -class RINValidationWithDepth(ImageNetValidationWithDepth): - def __init__(self, config=None, size=None, random_crop=False, crop_size=None): - sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" - super().__init__(config=config, size=size, random_crop=random_crop, - sub_indices=sub_indices, crop_size=crop_size) - - -class DRINExamples(Dataset): - def __init__(self): - self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) - with open("data/drin_examples.txt", "r") as f: - relpaths = f.read().splitlines() - self.image_paths = [os.path.join("data/drin_images", - relpath) for relpath in relpaths] - self.depth_paths = [os.path.join("data/drin_depth", - relpath.replace(".JPEG", ".png")) for relpath in relpaths] - - def __len__(self): - return len(self.image_paths) - - def preprocess_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - image = np.array(image).astype(np.uint8) - image = self.preprocessor(image=image)["image"] - image = (image/127.5 - 1.0).astype(np.float32) - return image - - def preprocess_depth(self, path): - rgba = np.array(Image.open(path)) - depth = rgba_to_depth(rgba) - depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) - depth = 2.0*depth-1.0 - return depth - - def __getitem__(self, i): - e = dict() - e["image"] = self.preprocess_image(self.image_paths[i]) - e["depth"] = self.preprocess_depth(self.depth_paths[i]) - transformed = self.preprocessor(image=e["image"], depth=e["depth"]) - e["image"] = transformed["image"] - e["depth"] = transformed["depth"] - return e - - -def imscale(x, factor, keepshapes=False, 
keepmode="bicubic"): - if factor is None or factor==1: - return x - - dtype = x.dtype - assert dtype in [np.float32, np.float64] - assert x.min() >= -1 - assert x.max() <= 1 - - keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, - "bicubic": Image.BICUBIC}[keepmode] - - lr = (x+1.0)*127.5 - lr = lr.clip(0,255).astype(np.uint8) - lr = Image.fromarray(lr) - - h, w, _ = x.shape - nh = h//factor - nw = w//factor - assert nh > 0 and nw > 0, (nh, nw) - - lr = lr.resize((nw,nh), Image.BICUBIC) - if keepshapes: - lr = lr.resize((w,h), keepmode) - lr = np.array(lr)/127.5-1.0 - lr = lr.astype(dtype) - - return lr - - -class ImageNetScale(Dataset): - def __init__(self, size=None, crop_size=None, random_crop=False, - up_factor=None, hr_factor=None, keep_mode="bicubic"): - self.base = self.get_base() - - self.size = size - self.crop_size = crop_size if crop_size is not None else self.size - self.random_crop = random_crop - self.up_factor = up_factor - self.hr_factor = hr_factor - self.keep_mode = keep_mode - - transforms = list() - - if self.size is not None and self.size > 0: - rescaler = albumentations.SmallestMaxSize(max_size = self.size) - self.rescaler = rescaler - transforms.append(rescaler) - - if self.crop_size is not None and self.crop_size > 0: - if len(transforms) == 0: - self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) - - if not self.random_crop: - cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) - else: - cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) - transforms.append(cropper) - - if len(transforms) > 0: - if self.up_factor is not None: - additional_targets = {"lr": "image"} - else: - additional_targets = None - self.preprocessor = albumentations.Compose(transforms, - additional_targets=additional_targets) - else: - self.preprocessor = lambda **kwargs: kwargs - - def __len__(self): - return len(self.base) - - def __getitem__(self, i): - example = self.base[i] - image = example["image"] - # adjust resolution - image = imscale(image, self.hr_factor, keepshapes=False) - h,w,c = image.shape - if self.crop_size and min(h,w) < self.crop_size: - # have to upscale to be able to crop - this just uses bilinear - image = self.rescaler(image=image)["image"] - if self.up_factor is None: - image = self.preprocessor(image=image)["image"] - example["image"] = image - else: - lr = imscale(image, self.up_factor, keepshapes=True, - keepmode=self.keep_mode) - - out = self.preprocessor(image=image, lr=lr) - example["image"] = out["image"] - example["lr"] = out["lr"] - - return example - -class ImageNetScaleTrain(ImageNetScale): - def __init__(self, random_crop=True, **kwargs): - super().__init__(random_crop=random_crop, **kwargs) - - def get_base(self): - return ImageNetTrain() - -class ImageNetScaleValidation(ImageNetScale): - def get_base(self): - return ImageNetValidation() - - -from skimage.feature import canny -from skimage.color import rgb2gray - - -class ImageNetEdges(ImageNetScale): - def __init__(self, up_factor=1, **kwargs): - super().__init__(up_factor=1, **kwargs) - - def __getitem__(self, i): - example = self.base[i] - image = example["image"] - h,w,c = image.shape - if self.crop_size and min(h,w) < self.crop_size: - # have to upscale to be able to crop - this just uses bilinear - image = self.rescaler(image=image)["image"] - - lr = canny(rgb2gray(image), sigma=2) - lr = lr.astype(np.float32) - lr = lr[:,:,None][:,:,[0,0,0]] - - out = self.preprocessor(image=image, lr=lr) - example["image"] 
= out["image"] - example["lr"] = out["lr"] - - return example - - -class ImageNetEdgesTrain(ImageNetEdges): - def __init__(self, random_crop=True, **kwargs): - super().__init__(random_crop=random_crop, **kwargs) - - def get_base(self): - return ImageNetTrain() - -class ImageNetEdgesValidation(ImageNetEdges): - def get_base(self): - return ImageNetValidation() diff --git a/Control-Color/taming/data/open_images_helper.py b/Control-Color/taming/data/open_images_helper.py deleted file mode 100644 index 8feb7c6e705fc165d2983303192aaa88f579b243..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/open_images_helper.py +++ /dev/null @@ -1,379 +0,0 @@ -open_images_unify_categories_for_coco = { - '/m/03bt1vf': '/m/01g317', - '/m/04yx4': '/m/01g317', - '/m/05r655': '/m/01g317', - '/m/01bl7v': '/m/01g317', - '/m/0cnyhnx': '/m/01xq0k1', - '/m/01226z': '/m/018xm', - '/m/05ctyq': '/m/018xm', - '/m/058qzx': '/m/04ctx', - '/m/06pcq': '/m/0l515', - '/m/03m3pdh': '/m/02crq1', - '/m/046dlr': '/m/01x3z', - '/m/0h8mzrc': '/m/01x3z', -} - - -top_300_classes_plus_coco_compatibility = [ - ('Man', 1060962), - ('Clothing', 986610), - ('Tree', 748162), - ('Woman', 611896), - ('Person', 610294), - ('Human face', 442948), - ('Girl', 175399), - ('Building', 162147), - ('Car', 159135), - ('Plant', 155704), - ('Human body', 137073), - ('Flower', 133128), - ('Window', 127485), - ('Human arm', 118380), - ('House', 114365), - ('Wheel', 111684), - ('Suit', 99054), - ('Human hair', 98089), - ('Human head', 92763), - ('Chair', 88624), - ('Boy', 79849), - ('Table', 73699), - ('Jeans', 57200), - ('Tire', 55725), - ('Skyscraper', 53321), - ('Food', 52400), - ('Footwear', 50335), - ('Dress', 50236), - ('Human leg', 47124), - ('Toy', 46636), - ('Tower', 45605), - ('Boat', 43486), - ('Land vehicle', 40541), - ('Bicycle wheel', 34646), - ('Palm tree', 33729), - ('Fashion accessory', 32914), - ('Glasses', 31940), - ('Bicycle', 31409), - ('Furniture', 30656), - ('Sculpture', 29643), - ('Bottle', 27558), - ('Dog', 26980), - ('Snack', 26796), - ('Human hand', 26664), - ('Bird', 25791), - ('Book', 25415), - ('Guitar', 24386), - ('Jacket', 23998), - ('Poster', 22192), - ('Dessert', 21284), - ('Baked goods', 20657), - ('Drink', 19754), - ('Flag', 18588), - ('Houseplant', 18205), - ('Tableware', 17613), - ('Airplane', 17218), - ('Door', 17195), - ('Sports uniform', 17068), - ('Shelf', 16865), - ('Drum', 16612), - ('Vehicle', 16542), - ('Microphone', 15269), - ('Street light', 14957), - ('Cat', 14879), - ('Fruit', 13684), - ('Fast food', 13536), - ('Animal', 12932), - ('Vegetable', 12534), - ('Train', 12358), - ('Horse', 11948), - ('Flowerpot', 11728), - ('Motorcycle', 11621), - ('Fish', 11517), - ('Desk', 11405), - ('Helmet', 10996), - ('Truck', 10915), - ('Bus', 10695), - ('Hat', 10532), - ('Auto part', 10488), - ('Musical instrument', 10303), - ('Sunglasses', 10207), - ('Picture frame', 10096), - ('Sports equipment', 10015), - ('Shorts', 9999), - ('Wine glass', 9632), - ('Duck', 9242), - ('Wine', 9032), - ('Rose', 8781), - ('Tie', 8693), - ('Butterfly', 8436), - ('Beer', 7978), - ('Cabinetry', 7956), - ('Laptop', 7907), - ('Insect', 7497), - ('Goggles', 7363), - ('Shirt', 7098), - ('Dairy Product', 7021), - ('Marine invertebrates', 7014), - ('Cattle', 7006), - ('Trousers', 6903), - ('Van', 6843), - ('Billboard', 6777), - ('Balloon', 6367), - ('Human nose', 6103), - ('Tent', 6073), - ('Camera', 6014), - ('Doll', 6002), - ('Coat', 5951), - ('Mobile phone', 5758), - ('Swimwear', 5729), - ('Strawberry', 5691), - ('Stairs', 
5643), - ('Goose', 5599), - ('Umbrella', 5536), - ('Cake', 5508), - ('Sun hat', 5475), - ('Bench', 5310), - ('Bookcase', 5163), - ('Bee', 5140), - ('Computer monitor', 5078), - ('Hiking equipment', 4983), - ('Office building', 4981), - ('Coffee cup', 4748), - ('Curtain', 4685), - ('Plate', 4651), - ('Box', 4621), - ('Tomato', 4595), - ('Coffee table', 4529), - ('Office supplies', 4473), - ('Maple', 4416), - ('Muffin', 4365), - ('Cocktail', 4234), - ('Castle', 4197), - ('Couch', 4134), - ('Pumpkin', 3983), - ('Computer keyboard', 3960), - ('Human mouth', 3926), - ('Christmas tree', 3893), - ('Mushroom', 3883), - ('Swimming pool', 3809), - ('Pastry', 3799), - ('Lavender (Plant)', 3769), - ('Football helmet', 3732), - ('Bread', 3648), - ('Traffic sign', 3628), - ('Common sunflower', 3597), - ('Television', 3550), - ('Bed', 3525), - ('Cookie', 3485), - ('Fountain', 3484), - ('Paddle', 3447), - ('Bicycle helmet', 3429), - ('Porch', 3420), - ('Deer', 3387), - ('Fedora', 3339), - ('Canoe', 3338), - ('Carnivore', 3266), - ('Bowl', 3202), - ('Human eye', 3166), - ('Ball', 3118), - ('Pillow', 3077), - ('Salad', 3061), - ('Beetle', 3060), - ('Orange', 3050), - ('Drawer', 2958), - ('Platter', 2937), - ('Elephant', 2921), - ('Seafood', 2921), - ('Monkey', 2915), - ('Countertop', 2879), - ('Watercraft', 2831), - ('Helicopter', 2805), - ('Kitchen appliance', 2797), - ('Personal flotation device', 2781), - ('Swan', 2739), - ('Lamp', 2711), - ('Boot', 2695), - ('Bronze sculpture', 2693), - ('Chicken', 2677), - ('Taxi', 2643), - ('Juice', 2615), - ('Cowboy hat', 2604), - ('Apple', 2600), - ('Tin can', 2590), - ('Necklace', 2564), - ('Ice cream', 2560), - ('Human beard', 2539), - ('Coin', 2536), - ('Candle', 2515), - ('Cart', 2512), - ('High heels', 2441), - ('Weapon', 2433), - ('Handbag', 2406), - ('Penguin', 2396), - ('Rifle', 2352), - ('Violin', 2336), - ('Skull', 2304), - ('Lantern', 2285), - ('Scarf', 2269), - ('Saucer', 2225), - ('Sheep', 2215), - ('Vase', 2189), - ('Lily', 2180), - ('Mug', 2154), - ('Parrot', 2140), - ('Human ear', 2137), - ('Sandal', 2115), - ('Lizard', 2100), - ('Kitchen & dining room table', 2063), - ('Spider', 1977), - ('Coffee', 1974), - ('Goat', 1926), - ('Squirrel', 1922), - ('Cello', 1913), - ('Sushi', 1881), - ('Tortoise', 1876), - ('Pizza', 1870), - ('Studio couch', 1864), - ('Barrel', 1862), - ('Cosmetics', 1841), - ('Moths and butterflies', 1841), - ('Convenience store', 1817), - ('Watch', 1792), - ('Home appliance', 1786), - ('Harbor seal', 1780), - ('Luggage and bags', 1756), - ('Vehicle registration plate', 1754), - ('Shrimp', 1751), - ('Jellyfish', 1730), - ('French fries', 1723), - ('Egg (Food)', 1698), - ('Football', 1697), - ('Musical keyboard', 1683), - ('Falcon', 1674), - ('Candy', 1660), - ('Medical equipment', 1654), - ('Eagle', 1651), - ('Dinosaur', 1634), - ('Surfboard', 1630), - ('Tank', 1628), - ('Grape', 1624), - ('Lion', 1624), - ('Owl', 1622), - ('Ski', 1613), - ('Waste container', 1606), - ('Frog', 1591), - ('Sparrow', 1585), - ('Rabbit', 1581), - ('Pen', 1546), - ('Sea lion', 1537), - ('Spoon', 1521), - ('Sink', 1512), - ('Teddy bear', 1507), - ('Bull', 1495), - ('Sofa bed', 1490), - ('Dragonfly', 1479), - ('Brassiere', 1478), - ('Chest of drawers', 1472), - ('Aircraft', 1466), - ('Human foot', 1463), - ('Pig', 1455), - ('Fork', 1454), - ('Antelope', 1438), - ('Tripod', 1427), - ('Tool', 1424), - ('Cheese', 1422), - ('Lemon', 1397), - ('Hamburger', 1393), - ('Dolphin', 1390), - ('Mirror', 1390), - ('Marine mammal', 1387), - ('Giraffe', 1385), - 
('Snake', 1368), - ('Gondola', 1364), - ('Wheelchair', 1360), - ('Piano', 1358), - ('Cupboard', 1348), - ('Banana', 1345), - ('Trumpet', 1335), - ('Lighthouse', 1333), - ('Invertebrate', 1317), - ('Carrot', 1268), - ('Sock', 1260), - ('Tiger', 1241), - ('Camel', 1224), - ('Parachute', 1224), - ('Bathroom accessory', 1223), - ('Earrings', 1221), - ('Headphones', 1218), - ('Skirt', 1198), - ('Skateboard', 1190), - ('Sandwich', 1148), - ('Saxophone', 1141), - ('Goldfish', 1136), - ('Stool', 1104), - ('Traffic light', 1097), - ('Shellfish', 1081), - ('Backpack', 1079), - ('Sea turtle', 1078), - ('Cucumber', 1075), - ('Tea', 1051), - ('Toilet', 1047), - ('Roller skates', 1040), - ('Mule', 1039), - ('Bust', 1031), - ('Broccoli', 1030), - ('Crab', 1020), - ('Oyster', 1019), - ('Cannon', 1012), - ('Zebra', 1012), - ('French horn', 1008), - ('Grapefruit', 998), - ('Whiteboard', 997), - ('Zucchini', 997), - ('Crocodile', 992), - - ('Clock', 960), - ('Wall clock', 958), - - ('Doughnut', 869), - ('Snail', 868), - - ('Baseball glove', 859), - - ('Panda', 830), - ('Tennis racket', 830), - - ('Pear', 652), - - ('Bagel', 617), - ('Oven', 616), - ('Ladybug', 615), - ('Shark', 615), - ('Polar bear', 614), - ('Ostrich', 609), - - ('Hot dog', 473), - ('Microwave oven', 467), - ('Fire hydrant', 20), - ('Stop sign', 20), - ('Parking meter', 20), - ('Bear', 20), - ('Flying disc', 20), - ('Snowboard', 20), - ('Tennis ball', 20), - ('Kite', 20), - ('Baseball bat', 20), - ('Kitchen knife', 20), - ('Knife', 20), - ('Submarine sandwich', 20), - ('Computer mouse', 20), - ('Remote control', 20), - ('Toaster', 20), - ('Sink', 20), - ('Refrigerator', 20), - ('Alarm clock', 20), - ('Wall clock', 20), - ('Scissors', 20), - ('Hair dryer', 20), - ('Toothbrush', 20), - ('Suitcase', 20) -] diff --git a/Control-Color/taming/data/sflckr.py b/Control-Color/taming/data/sflckr.py deleted file mode 100644 index 91101be5953b113f1e58376af637e43f366b3dee..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/sflckr.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import numpy as np -import cv2 -import albumentations -from PIL import Image -from torch.utils.data import Dataset - - -class SegmentationBase(Dataset): - def __init__(self, - data_csv, data_root, segmentation_root, - size=None, random_crop=False, interpolation="bicubic", - n_labels=182, shift_segmentation=False, - ): - self.n_labels = n_labels - self.shift_segmentation = shift_segmentation - self.data_csv = data_csv - self.data_root = data_root - self.segmentation_root = segmentation_root - with open(self.data_csv, "r") as f: - self.image_paths = f.read().splitlines() - self._length = len(self.image_paths) - self.labels = { - "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], - "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) - for l in self.image_paths] - } - - size = None if size is not None and size<=0 else size - self.size = size - if self.size is not None: - self.interpolation = interpolation - self.interpolation = { - "nearest": cv2.INTER_NEAREST, - "bilinear": cv2.INTER_LINEAR, - "bicubic": cv2.INTER_CUBIC, - "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] - self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=self.interpolation) - self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, - interpolation=cv2.INTER_NEAREST) - self.center_crop = not 
random_crop - if self.center_crop: - self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) - else: - self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) - self.preprocessor = self.cropper - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = dict((k, self.labels[k][i]) for k in self.labels) - image = Image.open(example["file_path_"]) - if not image.mode == "RGB": - image = image.convert("RGB") - image = np.array(image).astype(np.uint8) - if self.size is not None: - image = self.image_rescaler(image=image)["image"] - segmentation = Image.open(example["segmentation_path_"]) - assert segmentation.mode == "L", segmentation.mode - segmentation = np.array(segmentation).astype(np.uint8) - if self.shift_segmentation: - # used to support segmentations containing unlabeled==255 label - segmentation = segmentation+1 - if self.size is not None: - segmentation = self.segmentation_rescaler(image=segmentation)["image"] - if self.size is not None: - processed = self.preprocessor(image=image, - mask=segmentation - ) - else: - processed = {"image": image, - "mask": segmentation - } - example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) - segmentation = processed["mask"] - onehot = np.eye(self.n_labels)[segmentation] - example["segmentation"] = onehot - return example - - -class Examples(SegmentationBase): - def __init__(self, size=None, random_crop=False, interpolation="bicubic"): - super().__init__(data_csv="data/sflckr_examples.txt", - data_root="data/sflckr_images", - segmentation_root="data/sflckr_segmentations", - size=size, random_crop=random_crop, interpolation=interpolation) diff --git a/Control-Color/taming/data/utils.py b/Control-Color/taming/data/utils.py deleted file mode 100644 index 2b3c3d53cd2b6c72b481b59834cf809d3735b394..0000000000000000000000000000000000000000 --- a/Control-Color/taming/data/utils.py +++ /dev/null @@ -1,169 +0,0 @@ -import collections -import os -import tarfile -import urllib -import zipfile -from pathlib import Path - -import numpy as np -import torch -from taming.data.helper_types import Annotation -from torch._six import string_classes -from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format -from tqdm import tqdm - - -def unpack(path): - if path.endswith("tar.gz"): - with tarfile.open(path, "r:gz") as tar: - tar.extractall(path=os.path.split(path)[0]) - elif path.endswith("tar"): - with tarfile.open(path, "r:") as tar: - tar.extractall(path=os.path.split(path)[0]) - elif path.endswith("zip"): - with zipfile.ZipFile(path, "r") as f: - f.extractall(path=os.path.split(path)[0]) - else: - raise NotImplementedError( - "Unknown file extension: {}".format(os.path.splitext(path)[1]) - ) - - -def reporthook(bar): - """tqdm progress bar for downloads.""" - - def hook(b=1, bsize=1, tsize=None): - if tsize is not None: - bar.total = tsize - bar.update(b * bsize - bar.n) - - return hook - - -def get_root(name): - base = "data/" - root = os.path.join(base, name) - os.makedirs(root, exist_ok=True) - return root - - -def is_prepared(root): - return Path(root).joinpath(".ready").exists() - - -def mark_prepared(root): - Path(root).joinpath(".ready").touch() - - -def prompt_download(file_, source, target_dir, content_dir=None): - targetpath = os.path.join(target_dir, file_) - while not os.path.exists(targetpath): - if content_dir is not None and os.path.exists( - os.path.join(target_dir, content_dir) - ): - break - print( - "Please download '{}' 
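SegmentationBase.__getitem__ above one-hot encodes the integer label mask with np.eye(n_labels)[mask] (and optionally shifts labels first so the unlabeled==255 class gets its own index). A toy sketch of that encoding:

import numpy as np

# Sketch of the one-hot encoding done at the end of SegmentationBase.__getitem__ above.
n_labels = 5
mask = np.array([[0, 1], [4, 2]], dtype=np.uint8)                   # toy 2 x 2 label map
onehot = np.eye(n_labels)[mask]                                      # shape (2, 2, 5)
assert onehot.shape == (2, 2, n_labels)
assert onehot[1, 0].tolist() == [0.0, 0.0, 0.0, 0.0, 1.0]            # pixel with label 4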
from '{}' to '{}'.".format(file_, source, targetpath) - ) - if content_dir is not None: - print( - "Or place its content into '{}'.".format( - os.path.join(target_dir, content_dir) - ) - ) - input("Press Enter when done...") - return targetpath - - -def download_url(file_, url, target_dir): - targetpath = os.path.join(target_dir, file_) - os.makedirs(target_dir, exist_ok=True) - with tqdm( - unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ - ) as bar: - urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) - return targetpath - - -def download_urls(urls, target_dir): - paths = dict() - for fname, url in urls.items(): - outpath = download_url(fname, url, target_dir) - paths[fname] = outpath - return paths - - -def quadratic_crop(x, bbox, alpha=1.0): - """bbox is xmin, ymin, xmax, ymax""" - im_h, im_w = x.shape[:2] - bbox = np.array(bbox, dtype=np.float32) - bbox = np.clip(bbox, 0, max(im_h, im_w)) - center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) - w = bbox[2] - bbox[0] - h = bbox[3] - bbox[1] - l = int(alpha * max(w, h)) - l = max(l, 2) - - required_padding = -1 * min( - center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) - ) - required_padding = int(np.ceil(required_padding)) - if required_padding > 0: - padding = [ - [required_padding, required_padding], - [required_padding, required_padding], - ] - padding += [[0, 0]] * (len(x.shape) - 2) - x = np.pad(x, padding, "reflect") - center = center[0] + required_padding, center[1] + required_padding - xmin = int(center[0] - l / 2) - ymin = int(center[1] - l / 2) - return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) - - -def custom_collate(batch): - r"""source: pytorch 1.9.0, only one modification to original code """ - - elem = batch[0] - elem_type = type(elem) - if isinstance(elem, torch.Tensor): - out = None - if torch.utils.data.get_worker_info() is not None: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum([x.numel() for x in batch]) - storage = elem.storage()._new_shared(numel) - out = elem.new(storage) - return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ - and elem_type.__name__ != 'string_': - if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': - # array of string classes and object - if np_str_obj_array_pattern.search(elem.dtype.str) is not None: - raise TypeError(default_collate_err_msg_format.format(elem.dtype)) - - return custom_collate([torch.as_tensor(b) for b in batch]) - elif elem.shape == (): # scalars - return torch.as_tensor(batch) - elif isinstance(elem, float): - return torch.tensor(batch, dtype=torch.float64) - elif isinstance(elem, int): - return torch.tensor(batch) - elif isinstance(elem, string_classes): - return batch - elif isinstance(elem, collections.abc.Mapping): - return {key: custom_collate([d[key] for d in batch]) for key in elem} - elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple - return elem_type(*(custom_collate(samples) for samples in zip(*batch))) - if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added - return batch # added - elif isinstance(elem, collections.abc.Sequence): - # check to make sure that the elements in batch have consistent size - it = iter(batch) - elem_size = len(next(it)) - if not all(len(elem) == elem_size for elem in it): - raise RuntimeError('each element in list of batch should be of 
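download_url above feeds urllib.request.urlretrieve's reporthook callback into a tqdm progress bar. A self-contained sketch of that wiring; the URL and target path below are placeholders, not project assets:

import os
import urllib.request
from tqdm import tqdm

# Sketch of a tqdm-backed download in the style of download_url / reporthook above.
def download(url, targetpath):
    os.makedirs(os.path.dirname(targetpath) or ".", exist_ok=True)
    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1,
              desc=os.path.basename(targetpath)) as bar:
        def hook(blocks=1, block_size=1, total_size=None):
            if total_size is not None:
                bar.total = total_size
            bar.update(blocks * block_size - bar.n)
        urllib.request.urlretrieve(url, targetpath, reporthook=hook)

# download("https://example.com/file.bin", "data/file.bin")   # hypothetical usage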
equal size') - transposed = zip(*batch) - return [custom_collate(samples) for samples in transposed] - - raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/Control-Color/taming/lr_scheduler.py b/Control-Color/taming/lr_scheduler.py deleted file mode 100644 index e598ed120159c53da6820a55ad86b89f5c70c82d..0000000000000000000000000000000000000000 --- a/Control-Color/taming/lr_scheduler.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np - - -class LambdaWarmUpCosineScheduler: - """ - note: use with a base_lr of 1.0 - """ - def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): - self.lr_warm_up_steps = warm_up_steps - self.lr_start = lr_start - self.lr_min = lr_min - self.lr_max = lr_max - self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. - self.verbosity_interval = verbosity_interval - - def schedule(self, n): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") - if n < self.lr_warm_up_steps: - lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr - else: - t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) - t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) - self.last_lr = lr - return lr - - def __call__(self, n): - return self.schedule(n) - diff --git a/Control-Color/taming/models/cond_transformer.py b/Control-Color/taming/models/cond_transformer.py deleted file mode 100644 index e4c63730fa86ac1b92b37af14c14fb696595b1ab..0000000000000000000000000000000000000000 --- a/Control-Color/taming/models/cond_transformer.py +++ /dev/null @@ -1,352 +0,0 @@ -import os, math -import torch -import torch.nn.functional as F -import pytorch_lightning as pl - -from main import instantiate_from_config -from taming.modules.util import SOSProvider - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class Net2NetTransformer(pl.LightningModule): - def __init__(self, - transformer_config, - first_stage_config, - cond_stage_config, - permuter_config=None, - ckpt_path=None, - ignore_keys=[], - first_stage_key="image", - cond_stage_key="depth", - downsample_cond_size=-1, - pkeep=1.0, - sos_token=0, - unconditional=False, - ): - super().__init__() - self.be_unconditional = unconditional - self.sos_token = sos_token - self.first_stage_key = first_stage_key - self.cond_stage_key = cond_stage_key - self.init_first_stage_from_ckpt(first_stage_config) - self.init_cond_stage_from_ckpt(cond_stage_config) - if permuter_config is None: - permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} - self.permuter = instantiate_from_config(config=permuter_config) - self.transformer = instantiate_from_config(config=transformer_config) - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.downsample_cond_size = downsample_cond_size - self.pkeep = pkeep - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - for k in sd.keys(): - for ik in ignore_keys: - if k.startswith(ik): - self.print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def init_first_stage_from_ckpt(self, config): - model = 
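LambdaWarmUpCosineScheduler above returns a learning-rate multiplier (intended for a base_lr of 1.0): a linear warm-up from lr_start to lr_max, then a cosine decay to lr_min. A sketch of the same schedule with illustrative hyper-parameters:

import numpy as np

# Sketch of the warm-up + cosine multiplier computed by LambdaWarmUpCosineScheduler above.
def lr_multiplier(n, warm_up_steps=100, lr_start=0.0, lr_max=1.0, lr_min=0.1, max_decay_steps=1000):
    if n < warm_up_steps:
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(t * np.pi))

print(lr_multiplier(50))     # 0.5  (half-way through warm-up)
print(lr_multiplier(1000))   # 0.1  (fully decayed to lr_min)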
instantiate_from_config(config) - model = model.eval() - model.train = disabled_train - self.first_stage_model = model - - def init_cond_stage_from_ckpt(self, config): - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__" or self.be_unconditional: - print(f"Using no cond stage. Assuming the training is intended to be unconditional. " - f"Prepending {self.sos_token} as a sos token.") - self.be_unconditional = True - self.cond_stage_key = self.first_stage_key - self.cond_stage_model = SOSProvider(self.sos_token) - else: - model = instantiate_from_config(config) - model = model.eval() - model.train = disabled_train - self.cond_stage_model = model - - def forward(self, x, c): - # one step to produce the logits - _, z_indices = self.encode_to_z(x) - _, c_indices = self.encode_to_c(c) - - if self.training and self.pkeep < 1.0: - mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, - device=z_indices.device)) - mask = mask.round().to(dtype=torch.int64) - r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) - a_indices = mask*z_indices+(1-mask)*r_indices - else: - a_indices = z_indices - - cz_indices = torch.cat((c_indices, a_indices), dim=1) - - # target includes all sequence elements (no need to handle first one - # differently because we are conditioning) - target = z_indices - # make the prediction - logits, _ = self.transformer(cz_indices[:, :-1]) - # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: - c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) - quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) - if len(indices.shape) > 2: - indices = indices.view(c.shape[0], -1) - return quant_c, indices - - @torch.no_grad() - def decode_to_img(self, index, zshape): - index = self.permuter(index, reverse=True) - bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) - quant_z = self.first_stage_model.quantize.get_codebook_entry( - index.reshape(-1), shape=bhwc) - x = self.first_stage_model.decode(quant_z) - return x - - @torch.no_grad() - def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): - log = dict() - - N = 4 - if lr_interface: - x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) - else: - x, c = self.get_xc(batch, N) - x = x.to(device=self.device) - c = c.to(device=self.device) - - quant_z, z_indices = self.encode_to_z(x) - quant_c, c_indices = self.encode_to_c(c) - - # create a "half"" sample - z_start_indices = z_indices[:,:z_indices.shape[1]//2] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1]-z_start_indices.shape[1], - temperature=temperature if temperature is not None else 1.0, - sample=True, - top_k=top_k if top_k is not None else 100, - callback=callback if callback is not None else lambda k: None) - x_sample = self.decode_to_img(index_sample, quant_z.shape) - - # sample - z_start_indices = z_indices[:, :0] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1], - temperature=temperature if temperature is not None else 1.0, - sample=True, - top_k=top_k if top_k is not None else 100, - callback=callback if callback is not None else lambda k: None) - x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) - - # det sample - z_start_indices = z_indices[:, :0] - index_sample = self.sample(z_start_indices, c_indices, - steps=z_indices.shape[1], - 
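During training, Net2NetTransformer.forward above keeps each ground-truth code index with probability pkeep and replaces the rest with random vocabulary indices. A standalone torch sketch of that corruption step, with illustrative vocabulary size and sequence shape:

import torch

# Sketch of the pkeep token corruption used in Net2NetTransformer.forward above.
pkeep, vocab_size = 0.9, 1024
z_indices = torch.randint(vocab_size, (2, 256))                     # ground-truth code indices
mask = torch.bernoulli(pkeep * torch.ones_like(z_indices, dtype=torch.float))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, vocab_size)
a_indices = mask * z_indices + (1 - mask) * r_indices               # ~10% of tokens randomised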
sample=False, - callback=callback if callback is not None else lambda k: None) - x_sample_det = self.decode_to_img(index_sample, quant_z.shape) - - # reconstruction - x_rec = self.decode_to_img(z_indices, quant_z.shape) - - log["inputs"] = x - log["reconstructions"] = x_rec - - if self.cond_stage_key in ["objects_bbox", "objects_center_points"]: - figure_size = (x_rec.shape[2], x_rec.shape[3]) - dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"] - label_for_category_no = dataset.get_textual_label_for_category_no - plotter = dataset.conditional_builders[self.cond_stage_key].plot - log["conditioning"] = torch.zeros_like(log["reconstructions"]) - for i in range(quant_c.shape[0]): - log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size) - log["conditioning_rec"] = log["conditioning"] - elif self.cond_stage_key != "image": - cond_rec = self.cond_stage_model.decode(quant_c) - if self.cond_stage_key == "segmentation": - # get image from segmentation mask - num_classes = cond_rec.shape[1] - - c = torch.argmax(c, dim=1, keepdim=True) - c = F.one_hot(c, num_classes=num_classes) - c = c.squeeze(1).permute(0, 3, 1, 2).float() - c = self.cond_stage_model.to_rgb(c) - - cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) - cond_rec = F.one_hot(cond_rec, num_classes=num_classes) - cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() - cond_rec = self.cond_stage_model.to_rgb(cond_rec) - log["conditioning_rec"] = cond_rec - log["conditioning"] = c - - log["samples_half"] = x_sample - log["samples_nopix"] = x_sample_nopix - log["samples_det"] = x_sample_det - return log - - def get_input(self, key, batch): - x = batch[key] - if len(x.shape) == 3: - x = x[..., None] - if len(x.shape) == 4: - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) - if x.dtype == torch.double: - x = x.float() - return x - - def get_xc(self, batch, N=None): - x = self.get_input(self.first_stage_key, batch) - c = self.get_input(self.cond_stage_key, batch) - if N is not None: - x = x[:N] - c = c[:N] - return x, c - - def shared_step(self, batch, batch_idx): - x, c = self.get_xc(batch) - logits, target = self(x, c) - loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) - return loss - - def training_step(self, batch, batch_idx): - loss = self.shared_step(batch, batch_idx) - self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - return loss - - def validation_step(self, batch, batch_idx): - loss = self.shared_step(batch, batch_idx) - self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - return loss - - def configure_optimizers(self): - """ - Following minGPT: - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. 
- """ - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - whitelist_weight_modules = (torch.nn.Linear, ) - blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) - for mn, m in self.transformer.named_modules(): - for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - - if pn.endswith('bias'): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # special case the position embedding parameter in the root GPT module as not decayed - no_decay.add('pos_emb') - - # validate that we considered every parameter - param_dict = {pn: p for pn, p in self.transformer.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) - - # create the pytorch optimizer object - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) - return optimizer diff --git a/Control-Color/taming/models/dummy_cond_stage.py b/Control-Color/taming/models/dummy_cond_stage.py deleted file mode 100644 index 6e19938078752e09b926a3e749907ee99a258ca0..0000000000000000000000000000000000000000 --- a/Control-Color/taming/models/dummy_cond_stage.py +++ /dev/null @@ -1,22 +0,0 @@ -from torch import Tensor - - -class DummyCondStage: - def __init__(self, conditional_key): - self.conditional_key = conditional_key - self.train = None - - def eval(self): - return self - - @staticmethod - def encode(c: Tensor): - return c, None, (None, None, c) - - @staticmethod - def decode(c: Tensor): - return c - - @staticmethod - def to_rgb(c: Tensor): - return c diff --git a/Control-Color/taming/models/vqgan.py b/Control-Color/taming/models/vqgan.py deleted file mode 100644 index a6950baa5f739111cd64c17235dca8be3a5f8037..0000000000000000000000000000000000000000 --- a/Control-Color/taming/models/vqgan.py +++ /dev/null @@ -1,404 +0,0 @@ -import torch -import torch.nn.functional as F -import pytorch_lightning as pl - -from main import instantiate_from_config - -from taming.modules.diffusionmodules.model import Encoder, Decoder -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer -from taming.modules.vqvae.quantize import GumbelQuantize -from taming.modules.vqvae.quantize import EMAVectorQuantizer - -class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ): - super().__init__() - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - 
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.image_key = image_key - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input): - quant, diff, _ = self.encode(input) - dec = self.decode(quant) - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) - return x.float() - - def training_step(self, batch, batch_idx, optimizer_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return discloss - - def validation_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") - rec_loss = log_dict_ae["val/rec_loss"] - self.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) - self.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = 
torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. - return x - - -class VQSegmentationModel(VQModel): - def __init__(self, n_labels, *args, **kwargs): - super().__init__(*args, **kwargs) - self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) - - def configure_optimizers(self): - lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - return opt_ae - - def training_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - def validation_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - total_loss = log_dict_ae["val/total_loss"] - self.log("val/total_loss", total_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) - return aeloss - - @torch.no_grad() - def log_images(self, batch, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - # convert logits to indices - xrec = torch.argmax(xrec, dim=1, keepdim=True) - xrec = F.one_hot(xrec, num_classes=x.shape[1]) - xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec - return log - - -class VQNoDiscModel(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None - ): - super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, - ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, - colorize_nlabels=colorize_nlabels) - - def training_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") - output = pl.TrainResult(minimize=aeloss) - output.log("train/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) - output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return output - - def validation_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - 
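The to_rgb helper in VQModel above visualizes an n-channel map by projecting it to three channels with a fixed random 1x1 convolution and min-max normalizing to [-1, 1]. A self-contained sketch of that idea; the class count and spatial size below are assumed values for illustration:

    import torch
    import torch.nn.functional as F

    n_labels = 20                              # assumed number of segmentation classes
    seg = torch.randn(1, n_labels, 64, 64)     # dummy segmentation logits
    colorize = torch.randn(3, n_labels, 1, 1)  # fixed random projection, as in to_rgb

    rgb = F.conv2d(seg, weight=colorize)                       # (1, 3, 64, 64)
    rgb = 2.0 * (rgb - rgb.min()) / (rgb.max() - rgb.min()) - 1.0
    print(rgb.shape, rgb.min().item(), rgb.max().item())       # range is exactly [-1, 1]
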
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") - rec_loss = log_dict_ae["val/rec_loss"] - output = pl.EvalResult(checkpoint_on=rec_loss) - output.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) - output.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True) - output.log_dict(log_dict_ae) - - return output - - def configure_optimizers(self): - optimizer = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=self.learning_rate, betas=(0.5, 0.9)) - return optimizer - - -class GumbelVQ(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - temperature_scheduler_config, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - kl_weight=1e-8, - remap=None, - ): - - z_channels = ddconfig["z_channels"] - super().__init__(ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=ignore_keys, - image_key=image_key, - colorize_nlabels=colorize_nlabels, - monitor=monitor, - ) - - self.loss.n_classes = n_embed - self.vocab_size = n_embed - - self.quantize = GumbelQuantize(z_channels, embed_dim, - n_embed=n_embed, - kl_weight=kl_weight, temp_init=1.0, - remap=remap) - - self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def temperature_scheduling(self): - self.quantize.temperature = self.temperature_scheduler(self.global_step) - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode_code(self, code_b): - raise NotImplementedError - - def training_step(self, batch, batch_idx, optimizer_idx): - self.temperature_scheduling() - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return discloss - - def validation_step(self, batch, batch_idx): - x = self.get_input(batch, self.image_key) - xrec, qloss = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") - rec_loss = log_dict_ae["val/rec_loss"] - self.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def log_images(self, batch, **kwargs): - log = dict() - x = self.get_input(batch, 
self.image_key) - x = x.to(self.device) - # encode - h = self.encoder(x) - h = self.quant_conv(h) - quant, _, _ = self.quantize(h) - # decode - x_rec = self.decode(quant) - log["inputs"] = x - log["reconstructions"] = x_rec - return log - - -class EMAVQ(VQModel): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ): - super().__init__(ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=ignore_keys, - image_key=image_key, - colorize_nlabels=colorize_nlabels, - monitor=monitor, - ) - self.quantize = EMAVectorQuantizer(n_embed=n_embed, - embedding_dim=embed_dim, - beta=0.25, - remap=remap) - def configure_optimizers(self): - lr = self.learning_rate - #Remove self.quantize from parameter list since it is updated via EMA - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/Control-Color/taming/modules/__pycache__/util.cpython-38.pyc b/Control-Color/taming/modules/__pycache__/util.cpython-38.pyc deleted file mode 100644 index 4e210ae6ceae393267ada9b209c13073a5703c1b..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/modules/__pycache__/util.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/modules/autoencoder/lpips/vgg.pth b/Control-Color/taming/modules/autoencoder/lpips/vgg.pth deleted file mode 100644 index f57dcf5cc764d61c8a460365847fb2137ff0a62d..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/autoencoder/lpips/vgg.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 -size 7289 diff --git a/Control-Color/taming/modules/diffusionmodules/model.py b/Control-Color/taming/modules/diffusionmodules/model.py deleted file mode 100644 index d3a5db6aa2ef915e270f1ae135e4a9918fdd884c..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/diffusionmodules/model.py +++ /dev/null @@ -1,776 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
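Concretely, for embedding_dim d the first d/2 entries are sin(t * w_j) and the last d/2 are cos(t * w_j), with w_j = exp(-j * ln(10000) / (d/2 - 1)). A standalone sketch of the same computation; the timesteps and dimension below are assumed example values:

    import math
    import torch

    t = torch.tensor([0, 1, 5])     # assumed example timesteps
    d = 8                           # assumed embedding_dim
    half = d // 2
    w = torch.exp(torch.arange(half, dtype=torch.float32) * -(math.log(10000) / (half - 1)))
    emb = torch.cat([torch.sin(t.float()[:, None] * w),
                     torch.cos(t.float()[:, None] * w)], dim=1)
    print(emb.shape)                # torch.Size([3, 8])
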
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) - return emb - - -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = 
torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True): - super().__init__() - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - - def forward(self, x, t=None): - #assert x.shape[2] == 
x.shape[3] == self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, **ignore_kwargs): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) - - - def forward(self, x): - #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h 
= self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, **ignorekwargs): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VUNet(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - in_channels, c_channels, - resolution, z_channels, use_timestep=False, **ignore_kwargs): - super().__init__() - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - 
self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - self.conv_in = torch.nn.Conv2d(c_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - self.z_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=1, - stride=1, - padding=0) - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=2*block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - - def forward(self, x, z): - #assert x.shape[2] == x.shape[3] == self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - z = self.z_in(z) - h = torch.cat((h,z),dim=1) - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - 
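With the Encoder above, the spatial resolution halves once per resolution level except the last, and the output has 2*z_channels channels when double_z is set. A quick sketch of that bookkeeping; the config values below are assumed for illustration:

    # assumed example configuration, not taken from any file in this diff
    resolution, ch_mult, z_channels, double_z = 256, (1, 2, 4, 4), 4, True

    num_resolutions = len(ch_mult)
    latent_res = resolution // 2 ** (num_resolutions - 1)       # one Downsample per level except the last
    out_channels = 2 * z_channels if double_z else z_channels
    print(latent_res, out_channels)                              # 32 8 -> Encoder output is (B, 8, 32, 32)
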
-class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - diff --git a/Control-Color/taming/modules/discriminator/__pycache__/model.cpython-38.pyc b/Control-Color/taming/modules/discriminator/__pycache__/model.cpython-38.pyc deleted file mode 100644 index 1060863e45ba22778301440300b88265823e91a0..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/modules/discriminator/__pycache__/model.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/modules/discriminator/model.py b/Control-Color/taming/modules/discriminator/model.py deleted file mode 100644 index 2aaa3110d0a7bcd05de7eca1e45101589ca5af05..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/discriminator/model.py +++ /dev/null @@ -1,67 +0,0 @@ -import functools -import torch.nn as nn - - -from taming.modules.util import ActNorm - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator as in Pix2Pix - 
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py - """ - def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator, self).__init__() - if not use_actnorm: - norm_layer = nn.BatchNorm2d - else: - norm_layer = ActNorm - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func != nn.BatchNorm2d - else: - use_bias = norm_layer != nn.BatchNorm2d - - kw = 4 - padw = 1 - sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) - sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) - sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map - self.main = nn.Sequential(*sequence) - - def forward(self, input): - """Standard forward.""" - return self.main(input) diff --git a/Control-Color/taming/modules/losses/__init__.py b/Control-Color/taming/modules/losses/__init__.py deleted file mode 100644 index d09caf9eb805f849a517f1b23503e1a4d6ea1ec5..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/losses/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from taming.modules.losses.vqperceptual import DummyLoss - diff --git a/Control-Color/taming/modules/losses/__pycache__/__init__.cpython-38.pyc b/Control-Color/taming/modules/losses/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index b81df1475a7a278598c4f0ab46bf9ca6cbfa84b7..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/modules/losses/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/modules/losses/__pycache__/lpips.cpython-38.pyc b/Control-Color/taming/modules/losses/__pycache__/lpips.cpython-38.pyc deleted file mode 100644 index b255caecb231e31f0ca0707c9bcb1dcfb9435bc5..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/modules/losses/__pycache__/lpips.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc b/Control-Color/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc deleted file mode 100644 index f7a6b635260b259d85f3f6bb70d51881dfbefc66..0000000000000000000000000000000000000000 Binary files a/Control-Color/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc and /dev/null differ diff --git a/Control-Color/taming/modules/losses/lpips.py b/Control-Color/taming/modules/losses/lpips.py deleted file mode 100644 index a7280447694ffc302a7636e7e4d6183408e0aa95..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/losses/lpips.py +++ /dev/null @@ -1,123 +0,0 @@ 
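The NLayerDiscriminator above outputs a grid of per-patch logits rather than a single scalar. A small sketch of what that looks like for a 256x256 input; it assumes the taming package shown in this diff is importable, and the input size is an assumed example:

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator  # assumes package is installed

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape)   # torch.Size([1, 1, 30, 30]) -- one logit per overlapping image patch
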
-"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" - -import torch -import torch.nn as nn -from torchvision import models -from collections import namedtuple - -from taming.util import get_ckpt_path - - -class LPIPS(nn.Module): - # Learned perceptual metric - def __init__(self, use_dropout=True): - super().__init__() - self.scaling_layer = ScalingLayer() - self.chns = [64, 128, 256, 512, 512] # vg16 features - self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - self.load_from_pretrained() - for param in self.parameters(): - param.requires_grad = False - - def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") - self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) - print("loaded pretrained LPIPS loss from {}".format(ckpt)) - - @classmethod - def from_pretrained(cls, name="vgg_lpips"): - if name != "vgg_lpips": - raise NotImplementedError - model = cls() - ckpt = get_ckpt_path(name) - model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) - return model - - def forward(self, input, target): - in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) - outs0, outs1 = self.net(in0_input), self.net(in1_input) - feats0, feats1, diffs = {}, {}, {} - lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - for kk in range(len(self.chns)): - feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) - diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - - res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] - val = res[0] - for l in range(1, len(self.chns)): - val += res[l] - return val - - -class ScalingLayer(nn.Module): - def __init__(self): - super(ScalingLayer, self).__init__() - self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) - self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) - - def forward(self, inp): - return (inp - self.shift) / self.scale - - -class NetLinLayer(nn.Module): - """ A single linear layer which does a 1x1 conv """ - def __init__(self, chn_in, chn_out=1, use_dropout=False): - super(NetLinLayer, self).__init__() - layers = [nn.Dropout(), ] if (use_dropout) else [] - layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] - self.model = nn.Sequential(*layers) - - -class vgg16(torch.nn.Module): - def __init__(self, requires_grad=False, pretrained=True): - super(vgg16, self).__init__() - vgg_pretrained_features = models.vgg16(pretrained=pretrained).features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - self.N_slices = 5 - for x in range(4): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(4, 9): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(9, 16): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(16, 23): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in 
range(23, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h = self.slice1(X) - h_relu1_2 = h - h = self.slice2(h) - h_relu2_2 = h - h = self.slice3(h) - h_relu3_3 = h - h = self.slice4(h) - h_relu4_3 = h - h = self.slice5(h) - h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - return out - - -def normalize_tensor(x,eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) - return x/(norm_factor+eps) - - -def spatial_average(x, keepdim=True): - return x.mean([2,3],keepdim=keepdim) - diff --git a/Control-Color/taming/modules/losses/segmentation.py b/Control-Color/taming/modules/losses/segmentation.py deleted file mode 100644 index 4ba77deb5159a6307ed2acba9945e4764a4ff0a5..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/losses/segmentation.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class BCELoss(nn.Module): - def forward(self, prediction, target): - loss = F.binary_cross_entropy_with_logits(prediction,target) - return loss, {} - - -class BCELossWithQuant(nn.Module): - def __init__(self, codebook_weight=1.): - super().__init__() - self.codebook_weight = codebook_weight - - def forward(self, qloss, target, prediction, split): - bce_loss = F.binary_cross_entropy_with_logits(prediction,target) - loss = bce_loss + self.codebook_weight*qloss - return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/bce_loss".format(split): bce_loss.detach().mean(), - "{}/quant_loss".format(split): qloss.detach().mean() - } diff --git a/Control-Color/taming/modules/losses/vqperceptual.py b/Control-Color/taming/modules/losses/vqperceptual.py deleted file mode 100644 index 488477782b505b3a8e463bc1badb4d0ac85fdcc7..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/losses/vqperceptual.py +++ /dev/null @@ -1,241 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from taming.modules.losses.lpips import LPIPS -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init - - -class DummyLoss(nn.Module): - def __init__(self): - super().__init__() - - -def adopt_weight(weight, global_step, threshold=0, value=0.): - if global_step < threshold: - weight = value - return weight - - -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1. - logits_real)) - loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + - torch.mean(torch.nn.functional.softplus(logits_fake))) - return d_loss - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge"): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train"): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - 
"{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log - -class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): - - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights*nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not 
self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - else: - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log \ No newline at end of file diff --git a/Control-Color/taming/modules/misc/coord.py b/Control-Color/taming/modules/misc/coord.py deleted file mode 100644 index ee69b0c897b6b382ae673622e420f55e494f5b09..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/misc/coord.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -class CoordStage(object): - def __init__(self, n_embed, down_factor): - self.n_embed = n_embed - self.down_factor = down_factor - - def eval(self): - return self - - def encode(self, c): - """fake vqmodel interface""" - assert 0.0 <= c.min() and c.max() <= 1.0 - b,ch,h,w = c.shape - assert ch == 1 - - c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, - mode="area") - c = c.clamp(0.0, 1.0) - c = self.n_embed*c - c_quant = c.round() - c_ind = c_quant.to(dtype=torch.long) - - info = None, None, c_ind - return c_quant, None, info - - def decode(self, c): - c = c/self.n_embed - c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, - mode="nearest") - return c diff --git a/Control-Color/taming/modules/transformer/mingpt.py b/Control-Color/taming/modules/transformer/mingpt.py deleted file mode 100644 index d14b7b68117f4b9f297b2929397cd4f55089334c..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/transformer/mingpt.py +++ /dev/null @@ -1,415 +0,0 @@ -""" -taken from: https://github.com/karpathy/minGPT/ -GPT model: -- the initial stem consists of a combination of token encoding and a positional encoding -- the meat of it is a uniform sequence of Transformer 
blocks - - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block - - all blocks feed into a central residual pathway similar to resnets -- the final decoder is a linear projection into a vanilla Softmax classifier -""" - -import math -import logging - -import torch -import torch.nn as nn -from torch.nn import functional as F -from transformers import top_k_top_p_filtering - -logger = logging.getLogger(__name__) - - -class GPTConfig: - """ base GPT config, params common to all GPT versions """ - embd_pdrop = 0.1 - resid_pdrop = 0.1 - attn_pdrop = 0.1 - - def __init__(self, vocab_size, block_size, **kwargs): - self.vocab_size = vocab_size - self.block_size = block_size - for k,v in kwargs.items(): - setattr(self, k, v) - - -class GPT1Config(GPTConfig): - """ GPT-1 like network roughly 125M params """ - n_layer = 12 - n_head = 12 - n_embd = 768 - - -class CausalSelfAttention(nn.Module): - """ - A vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. - """ - - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads - self.key = nn.Linear(config.n_embd, config.n_embd) - self.query = nn.Linear(config.n_embd, config.n_embd) - self.value = nn.Linear(config.n_embd, config.n_embd) - # regularization - self.attn_drop = nn.Dropout(config.attn_pdrop) - self.resid_drop = nn.Dropout(config.resid_pdrop) - # output projection - self.proj = nn.Linear(config.n_embd, config.n_embd) - # causal mask to ensure that attention is only applied to the left in the input sequence - mask = torch.tril(torch.ones(config.block_size, - config.block_size)) - if hasattr(config, "n_unmasked"): - mask[:config.n_unmasked, :config.n_unmasked] = 1 - self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) - self.n_head = config.n_head - - def forward(self, x, layer_past=None): - B, T, C = x.size() - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - present = torch.stack((k, v)) - if layer_past is not None: - past_key, past_value = layer_past - k = torch.cat((past_key, k), dim=-2) - v = torch.cat((past_value, v), dim=-2) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - if layer_past is None: - att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) - - att = F.softmax(att, dim=-1) - att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_drop(self.proj(y)) - return y, present # TODO: check that this does not break anything - - -class Block(nn.Module): - """ an unassuming Transformer block """ - def __init__(self, config): - super().__init__() - self.ln1 = nn.LayerNorm(config.n_embd) - self.ln2 = nn.LayerNorm(config.n_embd) - self.attn = CausalSelfAttention(config) - self.mlp = 
nn.Sequential( - nn.Linear(config.n_embd, 4 * config.n_embd), - nn.GELU(), # nice - nn.Linear(4 * config.n_embd, config.n_embd), - nn.Dropout(config.resid_pdrop), - ) - - def forward(self, x, layer_past=None, return_present=False): - # TODO: check that training still works - if return_present: assert not self.training - # layer past: tuple of length two with B, nh, T, hs - attn, present = self.attn(self.ln1(x), layer_past=layer_past) - - x = x + attn - x = x + self.mlp(self.ln2(x)) - if layer_past is not None or return_present: - return x, present - return x - - -class GPT(nn.Module): - """ the full GPT language model, with a context size of block_size """ - def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, - embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): - super().__init__() - config = GPTConfig(vocab_size=vocab_size, block_size=block_size, - embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, - n_layer=n_layer, n_head=n_head, n_embd=n_embd, - n_unmasked=n_unmasked) - # input embedding stem - self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) - self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) - self.drop = nn.Dropout(config.embd_pdrop) - # transformer - self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) - # decoder head - self.ln_f = nn.LayerNorm(config.n_embd) - self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.block_size = config.block_size - self.apply(self._init_weights) - self.config = config - logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) - - def get_block_size(self): - return self.block_size - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, idx, embeddings=None, targets=None): - # forward the GPT model - token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector - - if embeddings is not None: # prepend explicit embeddings - token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) - - t = token_embeddings.shape[1] - assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
- position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector - x = self.drop(token_embeddings + position_embeddings) - x = self.blocks(x) - x = self.ln_f(x) - logits = self.head(x) - - # if we are given some desired targets also calculate the loss - loss = None - if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) - - return logits, loss - - def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): - # inference only - assert not self.training - token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector - if embeddings is not None: # prepend explicit embeddings - token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) - - if past is not None: - assert past_length is not None - past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head - past_shape = list(past.shape) - expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] - assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" - position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector - else: - position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] - - x = self.drop(token_embeddings + position_embeddings) - presents = [] # accumulate over layers - for i, block in enumerate(self.blocks): - x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True) - presents.append(present) - - x = self.ln_f(x) - logits = self.head(x) - # if we are given some desired targets also calculate the loss - loss = None - if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) - - return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head - - -class DummyGPT(nn.Module): - # for debugging - def __init__(self, add_value=1): - super().__init__() - self.add_value = add_value - - def forward(self, idx): - return idx + self.add_value, None - - -class CodeGPT(nn.Module): - """Takes in semi-embeddings""" - def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, - embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): - super().__init__() - config = GPTConfig(vocab_size=vocab_size, block_size=block_size, - embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, - n_layer=n_layer, n_head=n_head, n_embd=n_embd, - n_unmasked=n_unmasked) - # input embedding stem - self.tok_emb = nn.Linear(in_channels, config.n_embd) - self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) - self.drop = nn.Dropout(config.embd_pdrop) - # transformer - self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) - # decoder head - self.ln_f = nn.LayerNorm(config.n_embd) - self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.block_size = config.block_size - self.apply(self._init_weights) - self.config = config - logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) - - def get_block_size(self): - return self.block_size - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - 
def forward(self, idx, embeddings=None, targets=None): - # forward the GPT model - token_embeddings = self.tok_emb(idx) # project the input features to the embedding dimension - - if embeddings is not None: # prepend explicit embeddings - token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) - - t = token_embeddings.shape[1] - assert t <= self.block_size, "Cannot forward, model block size is exhausted." - position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector - x = self.drop(token_embeddings + position_embeddings) - x = self.blocks(x) - x = self.ln_f(x) - logits = self.head(x) - - # if we are given some desired targets also calculate the loss - loss = None - if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) - - return logits, loss - - - -#### sampling utils - -def top_k_logits(logits, k): - v, ix = torch.topk(logits, k) - out = logits.clone() - out[out < v[:, [-1]]] = -float('Inf') - return out - -@torch.no_grad() -def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): - """ - take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in - the sequence, feeding the predictions back into the model each time. Clearly the sampling - has quadratic complexity unlike an RNN that is only linear, and has a finite context window - of block_size, unlike an RNN that has an infinite context window. - """ - block_size = model.get_block_size() - model.eval() - for k in range(steps): - x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed - logits, _ = model(x_cond) - # pluck the logits at the final step and scale by temperature - logits = logits[:, -1, :] / temperature - # optionally crop probabilities to only the top k options - if top_k is not None: - logits = top_k_logits(logits, top_k) - # apply softmax to convert to probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution or take the most likely - if sample: - ix = torch.multinomial(probs, num_samples=1) - else: - _, ix = torch.topk(probs, k=1, dim=-1) - # append to the sequence and continue - x = torch.cat((x, ix), dim=1) - - return x - - -@torch.no_grad() -def sample_with_past(x, model, steps, temperature=1., sample_logits=True, - top_k=None, top_p=None, callback=None): - # x is conditioning - sample = x - cond_len = x.shape[1] - past = None - for n in range(steps): - if callback is not None: - callback(n) - logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) - if past is None: - past = [present] - else: - past.append(present) - logits = logits[:, -1, :] / temperature - if top_k is not None: - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) - - probs = F.softmax(logits, dim=-1) - if not sample_logits: - _, x = torch.topk(probs, k=1, dim=-1) - else: - x = torch.multinomial(probs, num_samples=1) - # append to the sequence and continue - sample = torch.cat((sample, x), dim=1) - del past - sample = sample[:, cond_len:] # cut conditioning off - return sample - - -#### clustering utils - -class KMeans(nn.Module): - def __init__(self, ncluster=512, nc=3, niter=10): - super().__init__() - self.ncluster = ncluster - self.nc = nc - self.niter = niter - self.shape = (3,32,32) - self.register_buffer("C", torch.zeros(self.ncluster,nc)) - self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) - - def is_initialized(self): - return self.initialized.item() == 1 - - @torch.no_grad() - def 
initialize(self, x): - N, D = x.shape - assert D == self.nc, D - c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random - for i in range(self.niter): - # assign all pixels to the closest codebook element - a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) - # move each codebook element to be the mean of the pixels that assigned to it - c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) - # re-assign any poorly positioned codebook elements - nanix = torch.any(torch.isnan(c), dim=1) - ndead = nanix.sum().item() - print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) - c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters - - self.C.copy_(c) - self.initialized.fill_(1) - - - def forward(self, x, reverse=False, shape=None): - if not reverse: - # flatten - bs,c,h,w = x.shape - assert c == self.nc - x = x.reshape(bs,c,h*w,1) - C = self.C.permute(1,0) - C = C.reshape(1,c,1,self.ncluster) - a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices - return a - else: - # flatten - bs, HW = x.shape - """ - c = self.C.reshape( 1, self.nc, 1, self.ncluster) - c = c[bs*[0],:,:,:] - c = c[:,:,HW*[0],:] - x = x.reshape(bs, 1, HW, 1) - x = x[:,3*[0],:,:] - x = torch.gather(c, dim=3, index=x) - """ - x = self.C[x] - x = x.permute(0,2,1) - shape = shape if shape is not None else self.shape - x = x.reshape(bs, *shape) - - return x diff --git a/Control-Color/taming/modules/transformer/permuter.py b/Control-Color/taming/modules/transformer/permuter.py deleted file mode 100644 index 0d43bb135adde38d94bf18a7e5edaa4523cd95cf..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/transformer/permuter.py +++ /dev/null @@ -1,248 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np - - -class AbstractPermuter(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - def forward(self, x, reverse=False): - raise NotImplementedError - - -class Identity(AbstractPermuter): - def __init__(self): - super().__init__() - - def forward(self, x, reverse=False): - return x - - -class Subsample(AbstractPermuter): - def __init__(self, H, W): - super().__init__() - C = 1 - indices = np.arange(H*W).reshape(C,H,W) - while min(H, W) > 1: - indices = indices.reshape(C,H//2,2,W//2,2) - indices = indices.transpose(0,2,4,1,3) - indices = indices.reshape(C*4,H//2, W//2) - H = H//2 - W = W//2 - C = C*4 - assert H == W == 1 - idx = torch.tensor(indices.ravel()) - self.register_buffer('forward_shuffle_idx', - nn.Parameter(idx, requires_grad=False)) - self.register_buffer('backward_shuffle_idx', - nn.Parameter(torch.argsort(idx), requires_grad=False)) - - def forward(self, x, reverse=False): - if not reverse: - return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -def mortonify(i, j): - """(i,j) index to linear morton code""" - i = np.uint64(i) - j = np.uint64(j) - - z = np.uint(0) - - for pos in range(32): - z = (z | - ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | - ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) - ) - return z - - -class ZCurve(AbstractPermuter): - def __init__(self, H, W): - super().__init__() - reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] - idx = np.argsort(reverseidx) - idx = torch.tensor(idx) - reverseidx = torch.tensor(reverseidx) - self.register_buffer('forward_shuffle_idx', - idx) - self.register_buffer('backward_shuffle_idx', - reverseidx) - - def forward(self, x, reverse=False): - if not reverse: - 
return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -class SpiralOut(AbstractPermuter): - def __init__(self, H, W): - super().__init__() - assert H == W - size = W - indices = np.arange(size*size).reshape(size,size) - - i0 = size//2 - j0 = size//2-1 - - i = i0 - j = j0 - - idx = [indices[i0, j0]] - step_mult = 0 - for c in range(1, size//2+1): - step_mult += 1 - # steps left - for k in range(step_mult): - i = i - 1 - j = j - idx.append(indices[i, j]) - - # step down - for k in range(step_mult): - i = i - j = j + 1 - idx.append(indices[i, j]) - - step_mult += 1 - if c < size//2: - # step right - for k in range(step_mult): - i = i + 1 - j = j - idx.append(indices[i, j]) - - # step up - for k in range(step_mult): - i = i - j = j - 1 - idx.append(indices[i, j]) - else: - # end reached - for k in range(step_mult-1): - i = i + 1 - idx.append(indices[i, j]) - - assert len(idx) == size*size - idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) - - def forward(self, x, reverse=False): - if not reverse: - return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -class SpiralIn(AbstractPermuter): - def __init__(self, H, W): - super().__init__() - assert H == W - size = W - indices = np.arange(size*size).reshape(size,size) - - i0 = size//2 - j0 = size//2-1 - - i = i0 - j = j0 - - idx = [indices[i0, j0]] - step_mult = 0 - for c in range(1, size//2+1): - step_mult += 1 - # steps left - for k in range(step_mult): - i = i - 1 - j = j - idx.append(indices[i, j]) - - # step down - for k in range(step_mult): - i = i - j = j + 1 - idx.append(indices[i, j]) - - step_mult += 1 - if c < size//2: - # step right - for k in range(step_mult): - i = i + 1 - j = j - idx.append(indices[i, j]) - - # step up - for k in range(step_mult): - i = i - j = j - 1 - idx.append(indices[i, j]) - else: - # end reached - for k in range(step_mult-1): - i = i + 1 - idx.append(indices[i, j]) - - assert len(idx) == size*size - idx = idx[::-1] - idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) - - def forward(self, x, reverse=False): - if not reverse: - return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -class Random(nn.Module): - def __init__(self, H, W): - super().__init__() - indices = np.random.RandomState(1).permutation(H*W) - idx = torch.tensor(indices.ravel()) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) - - def forward(self, x, reverse=False): - if not reverse: - return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -class AlternateParsing(AbstractPermuter): - def __init__(self, H, W): - super().__init__() - indices = np.arange(W*H).reshape(H,W) - for i in range(1, H, 2): - indices[i, :] = indices[i, ::-1] - idx = indices.flatten() - assert len(idx) == H*W - idx = torch.tensor(idx) - self.register_buffer('forward_shuffle_idx', idx) - self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) - - def forward(self, x, reverse=False): - if not reverse: - return x[:, self.forward_shuffle_idx] - else: - return x[:, self.backward_shuffle_idx] - - -if __name__ == "__main__": - p0 = AlternateParsing(16, 16) - print(p0.forward_shuffle_idx) - print(p0.backward_shuffle_idx) - - x = torch.randint(0, 768, size=(11, 256)) - y = p0(x) - xre 
= p0(y, reverse=True) - assert torch.equal(x, xre) - - p1 = SpiralOut(2, 2) - print(p1.forward_shuffle_idx) - print(p1.backward_shuffle_idx) diff --git a/Control-Color/taming/modules/util.py b/Control-Color/taming/modules/util.py deleted file mode 100644 index 9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/util.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch -import torch.nn as nn - - -def count_params(model): - total_params = sum(p.numel() for p in model.parameters()) - return total_params - - -class ActNorm(nn.Module): - def __init__(self, num_features, logdet=False, affine=True, - allow_reverse_init=False): - assert affine - super().__init__() - self.logdet = logdet - self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) - self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) - self.allow_reverse_init = allow_reverse_init - - self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) - - def initialize(self, input): - with torch.no_grad(): - flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) - mean = ( - flatten.mean(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - std = ( - flatten.std(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - - self.loc.data.copy_(-mean) - self.scale.data.copy_(1 / (std + 1e-6)) - - def forward(self, input, reverse=False): - if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:,:,None,None] - squeeze = True - else: - squeeze = False - - _, _, height, width = input.shape - - if self.training and self.initialized.item() == 0: - self.initialize(input) - self.initialized.fill_(1) - - h = self.scale * (input + self.loc) - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - - if self.logdet: - log_abs = torch.log(torch.abs(self.scale)) - logdet = height*width*torch.sum(log_abs) - logdet = logdet * torch.ones(input.shape[0]).to(input) - return h, logdet - - return h - - def reverse(self, output): - if self.training and self.initialized.item() == 0: - if not self.allow_reverse_init: - raise RuntimeError( - "Initializing ActNorm in reverse direction is " - "disabled by default. Use allow_reverse_init=True to enable." 
- ) - else: - self.initialize(output) - self.initialized.fill_(1) - - if len(output.shape) == 2: - output = output[:,:,None,None] - squeeze = True - else: - squeeze = False - - h = output / self.scale - self.loc - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - return h - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class Labelator(AbstractEncoder): - """Net2Net Interface for Class-Conditional Model""" - def __init__(self, n_classes, quantize_interface=True): - super().__init__() - self.n_classes = n_classes - self.quantize_interface = quantize_interface - - def encode(self, c): - c = c[:,None] - if self.quantize_interface: - return c, None, [None, None, c.long()] - return c - - -class SOSProvider(AbstractEncoder): - # for unconditional training - def __init__(self, sos_token, quantize_interface=True): - super().__init__() - self.sos_token = sos_token - self.quantize_interface = quantize_interface - - def encode(self, x): - # get batch size from data and replicate sos_token - c = torch.ones(x.shape[0], 1)*self.sos_token - c = c.long().to(x.device) - if self.quantize_interface: - return c, None, [None, None, c] - return c diff --git a/Control-Color/taming/modules/vqvae/quantize.py b/Control-Color/taming/modules/vqvae/quantize.py deleted file mode 100644 index d75544e41fa01bce49dd822b1037963d62f79b51..0000000000000000000000000000000000000000 --- a/Control-Color/taming/modules/vqvae/quantize.py +++ /dev/null @@ -1,445 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from torch import einsum -from einops import rearrange - - -class VectorQuantizer(nn.Module): - """ - see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py - ____________________________________________ - Discretization bottleneck part of the VQ-VAE. - Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ - """ - - # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for - # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be - # used wherever VectorQuantizer has been used before and is additionally - # more efficient. - def __init__(self, n_e, e_dim, beta): - super(VectorQuantizer, self).__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - def forward(self, z): - """ - Inputs the output of the encoder network z and maps it to a discrete - one-hot vector that is the index of the closest embedding vector e_j - z (continuous) -> z_q (discrete) - z.shape = (batch, channel, height, width) - quantization pipeline: - 1. get encoder input (B,C,H,W) - 2. flatten input to (B*H*W,C) - """ - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight**2, dim=1) - 2 * \ - torch.matmul(z_flattened, self.embedding.weight.t()) - - ## could possible replace this here - # #\start... 
- # find closest encodings - min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - - min_encodings = torch.zeros( - min_encoding_indices.shape[0], self.n_e).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # dtype min encodings: torch.float32 - # min_encodings shape: torch.Size([2048, 512]) - # min_encoding_indices.shape: torch.Size([2048, 1]) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - #.........\end - - # with: - # .........\start - #min_encoding_indices = torch.argmin(d, dim=1) - #z_q = self.embedding(min_encoding_indices) - # ......\end......... (TODO) - - # compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - # TODO: check for more easy handling with nn.Embedding - min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) - min_encodings.scatter_(1, indices[:,None], 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: - z_q = z_q.view(shape) - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantize(nn.Module): - """ - credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) - Gumbel Softmax trick quantizer - Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 - https://arxiv.org/abs/1611.01144 - """ - def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, - kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, - remap=None, unknown_index="random"): - super().__init__() - - self.embedding_dim = embedding_dim - self.n_embed = n_embed - - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - - self.proj = nn.Conv2d(num_hiddens, n_embed, 1) - self.embed = nn.Embedding(n_embed, embedding_dim) - - self.use_vqinterface = use_vqinterface - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" - f"Using {self.unknown_index} for unknown indices.") - else: - self.re_embed = n_embed - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() - new = match.argmax(-1) - unknown = match.sum(2)<1 - if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, return_logits=False): - # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work - hard = self.straight_through if self.training else True - temp = self.temperature if temp is None else temp - - logits = self.proj(z) - if self.remap is not None: - # continue only with used logits - full_zeros = torch.zeros_like(logits) - logits = logits[:,self.used,...] - - soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) - if self.remap is not None: - # go back to all entries but unused set to zero - full_zeros[:,self.used,...] = soft_one_hot - soft_one_hot = full_zeros - z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() - - ind = soft_one_hot.argmax(dim=1) - if self.remap is not None: - ind = self.remap_to_used(ind) - if self.use_vqinterface: - if return_logits: - return z_q, diff, (None, None, ind), logits - return z_q, diff, (None, None, ind) - return z_q, diff, ind - - def get_codebook_entry(self, indices, shape): - b, h, w, c = shape - assert b*h*w == indices.shape[0] - indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) - if self.remap is not None: - indices = self.unmap_to_all(indices) - one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() - z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) - return z_q - - -class VectorQuantizer2(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. - """ - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. 
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", - sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() - new = match.argmax(-1) - unknown = match.sum(2)<1 - if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" - assert rescale_logits==False, "Only for interface compatible with Gumbel" - assert return_logits==False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, 'b c h w -> b h w c').contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight**2, dim=1) - 2 * \ - torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ - torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if 
self.remap is not None: - indices = indices.reshape(shape[0],-1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - -class EmbeddingEMA(nn.Module): - def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): - super().__init__() - self.decay = decay - self.eps = eps - weight = torch.randn(num_tokens, codebook_dim) - self.weight = nn.Parameter(weight, requires_grad = False) - self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) - self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) - self.update = True - - def forward(self, embed_id): - return F.embedding(embed_id, self.weight) - - def cluster_size_ema_update(self, new_cluster_size): - self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) - - def embed_avg_ema_update(self, new_embed_avg): - self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) - - def weight_update(self, num_tokens): - n = self.cluster_size.sum() - smoothed_cluster_size = ( - (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n - ) - #normalize embedding average with smoothed cluster size - embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) - self.weight.data.copy_(embed_normalized) - - -class EMAVectorQuantizer(nn.Module): - def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, - remap=None, unknown_index="random"): - super().__init__() - self.codebook_dim = embedding_dim - self.num_tokens = n_embed - self.beta = beta - self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. 
" - f"Using {self.unknown_index} for unknown indices.") - else: - self.re_embed = n_embed - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() - new = match.argmax(-1) - unknown = match.sum(2)<1 - if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - #z, 'b c h w -> b h w c' - z = rearrange(z, 'b c h w -> b h w c') - z_flattened = z.reshape(-1, self.codebook_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ - self.embedding.weight.pow(2).sum(dim=1) - 2 * \ - torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' - - - encoding_indices = torch.argmin(d, dim=1) - - z_q = self.embedding(encoding_indices).view(z.shape) - encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) - avg_probs = torch.mean(encodings, dim=0) - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - if self.training and self.embedding.update: - #EMA cluster size - encodings_sum = encodings.sum(0) - self.embedding.cluster_size_ema_update(encodings_sum) - #EMA embedding average - embed_sum = encodings.transpose(0,1) @ z_flattened - self.embedding.embed_avg_ema_update(embed_sum) - #normalize embed_avg and update weight - self.embedding.weight_update(self.num_tokens) - - # compute loss for embedding - loss = self.beta * F.mse_loss(z_q.detach(), z) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - #z_q, 'b h w c -> b c h w' - z_q = rearrange(z_q, 'b h w c -> b c h w') - return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/Control-Color/taming/util.py b/Control-Color/taming/util.py deleted file mode 100644 index 06053e5defb87977f9ab07e69bf4da12201de9b7..0000000000000000000000000000000000000000 --- a/Control-Color/taming/util.py +++ /dev/null @@ -1,157 +0,0 @@ -import os, hashlib -import requests -from tqdm import tqdm - -URL_MAP = { - "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" -} - -CKPT_MAP = { - "vgg_lpips": "vgg.pth" -} - -MD5_MAP = { - "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" -} - - -def download(url, local_path, chunk_size=1024): - os.makedirs(os.path.split(local_path)[0], exist_ok=True) - with requests.get(url, stream=True) as r: - total_size = int(r.headers.get("content-length", 0)) - with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: - with open(local_path, "wb") as f: - for data in r.iter_content(chunk_size=chunk_size): - if data: - f.write(data) - pbar.update(chunk_size) - - -def md5_hash(path): - with open(path, "rb") as f: - content = f.read() - return hashlib.md5(content).hexdigest() - - -def get_ckpt_path(name, root, check=False): - assert name in URL_MAP - path = os.path.join(root, CKPT_MAP[name]) - if not 
os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): - print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) - download(URL_MAP[name], path) - md5 = md5_hash(path) - assert md5 == MD5_MAP[name], md5 - return path - - -class KeyNotFoundError(Exception): - def __init__(self, cause, keys=None, visited=None): - self.cause = cause - self.keys = keys - self.visited = visited - messages = list() - if keys is not None: - messages.append("Key not found: {}".format(keys)) - if visited is not None: - messages.append("Visited: {}".format(visited)) - messages.append("Cause:\n{}".format(cause)) - message = "\n".join(messages) - super().__init__(message) - - -def retrieve( - list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False -): - """Given a nested list or dict return the desired value at key expanding - callable nodes if necessary and :attr:`expand` is ``True``. The expansion - is done in-place. - - Parameters - ---------- - list_or_dict : list or dict - Possibly nested list or dictionary. - key : str - key/to/value, path like string describing all keys necessary to - consider to get to the desired value. List indices can also be - passed here. - splitval : str - String that defines the delimiter between keys of the - different depth levels in `key`. - default : obj - Value returned if :attr:`key` is not found. - expand : bool - Whether to expand callable nodes on the path or not. - - Returns - ------- - The desired value or if :attr:`default` is not ``None`` and the - :attr:`key` is not found returns ``default``. - - Raises - ------ - Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is - ``None``. - """ - - keys = key.split(splitval) - - success = True - try: - visited = [] - parent = None - last_key = None - for key in keys: - if callable(list_or_dict): - if not expand: - raise KeyNotFoundError( - ValueError( - "Trying to get past callable node with expand=False." - ), - keys=keys, - visited=visited, - ) - list_or_dict = list_or_dict() - parent[last_key] = list_or_dict - - last_key = key - parent = list_or_dict - - try: - if isinstance(list_or_dict, dict): - list_or_dict = list_or_dict[key] - else: - list_or_dict = list_or_dict[int(key)] - except (KeyError, IndexError, ValueError) as e: - raise KeyNotFoundError(e, keys=keys, visited=visited) - - visited += [key] - # final expansion of retrieved value - if expand and callable(list_or_dict): - list_or_dict = list_or_dict() - parent[last_key] = list_or_dict - except KeyNotFoundError as e: - if default is None: - raise e - else: - list_or_dict = default - success = False - - if not pass_success: - return list_or_dict - else: - return list_or_dict, success - - -if __name__ == "__main__": - config = {"keya": "a", - "keyb": "b", - "keyc": - {"cc1": 1, - "cc2": 2, - } - } - from omegaconf import OmegaConf - config = OmegaConf.create(config) - print(config) - retrieve(config, "keya") -
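As a quick reference for the retrieve helper deleted above, a minimal usage sketch; the config dict, key paths, and values below are made up purely for illustration and assume retrieve as defined in taming/util.py:

cfg = {"model": {"params": {"lr": 1e-4}}, "name": "vq"}
retrieve(cfg, "model/params/lr")                 # -> 0.0001, keys split on "/" and walked in order
retrieve(cfg, "model/params/batch", default=16)  # -> 16, missing key falls back to the given default
retrieve(cfg, "name", pass_success=True)         # -> ("vq", True), second element reports lookup success

lazy = {"data": lambda: {"n": 3}}
retrieve(lazy, "data/n")                         # -> 3; the callable node is expanded in place (expand=True)

With default=None, a missing key raises KeyNotFoundError instead of returning a fallback value.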