diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9ac2dd89b438fcac05c40e030e26bc5a3aaad980 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..ab4cdf0f4dc0b38914cefbc15e7a08cd3342c6c1 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,32 @@ +from PIL import Image import gradio as gr +from huggingface_hub import hf_hub_download +from model import Model +from app_canny import create_demo as create_demo_canny +from app_depth import create_demo as create_demo_depth +import os -def greet(name): - return "Hello " + name + "!!" -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() +hf_hub_download('wondervictor/ControlAR', filename='canny_MR.safetensors', cache_dir='./checkpoints/') +hf_hub_download('wondervictor/ControlAR', filename='depth_MR.safetensors', cache_dir='./checkpoints/') + + +DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)." +SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1" +model = Model() +device = "cuda" +with gr.Blocks(css="style.css") as demo: + gr.Markdown(DESCRIPTION) + gr.DuplicateButton( + value="Duplicate Space for private use", + elem_id="duplicate-button", + visible=SHOW_DUPLICATE_BUTTON, + ) + with gr.Tabs(): + with gr.TabItem("Depth"): + create_demo_depth(model.process_depth) + with gr.TabItem("Canny"): + create_demo_canny(model.process_canny) + +if __name__ == "__main__": + demo.queue().launch(share=False, server_name="0.0.0.0") diff --git a/app_canny.py b/app_canny.py new file mode 100644 index 0000000000000000000000000000000000000000..b532a9d3d5bb6de89c199ab43179acd739ae171b --- /dev/null +++ b/app_canny.py @@ -0,0 +1,100 @@ +import gradio as gr +import random +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, 100000000) + return seed +examples = [ + [ + "condition/example/t2i/multigen/doll.png", + "A stuffed animal wearing a mask and a leash, sitting on a blanket", + "(512, 512)" + ], + [ + "condition/example/t2i/multigen/girl.png", + "An anime style girl with blue hair", + "(512, 512)" + ], + [ + "condition/example/t2i/multi_resolution/bird.jpg", + "colorful bird", + "(921, 564)" + ], +] +def create_demo(process): + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + image = gr.Image() + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button("Run") + with gr.Accordion("Advanced options", open=False): + canny_low_threshold = gr.Slider( + label="Canny low threshold", minimum=0, maximum=1000, value=100, step=50 + ) + canny_high_threshold = gr.Slider( + label="Canny high threshold", minimum=0, maximum=1000, value=200, step=50 + ) + cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1) + relolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16) + top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K') + top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P") + temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature') + seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + with gr.Column(): + result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, 
object_fit="scale-down") + gr.Examples( + examples=examples, + inputs=[ + image, + prompt, + relolution, + ], + outputs=result, + fn=process, + ) + inputs = [ + image, + prompt, + cfg_scale, + temperature, + top_k, + top_p, + seed, + canny_low_threshold, + canny_high_threshold, + ] + prompt.submit( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=process, + inputs=inputs, + outputs=result, + api_name=False, + ) + run_button.click( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=process, + inputs=inputs, + outputs=result, + api_name="canny", + ) + return demo +if __name__ == "__main__": + from model import Model + model = Model() + demo = create_demo(model.process_canny) + demo.queue().launch( + share=False, + server_name="0.0.0.0" + ) \ No newline at end of file diff --git a/app_depth.py b/app_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5cb989378bdf1a0c869bebc9370e70fe09c324 --- /dev/null +++ b/app_depth.py @@ -0,0 +1,92 @@ +import gradio as gr +import random +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, 100000000) + return seed +examples = [ + [ + "condition/example/t2i/multigen/sofa.png", + "The red sofa in the living room has several pillows on it", + "(512, 512)" + ], + [ + "condition/example/t2i/multigen/house.png", + "A brick house with a chimney under a starry sky.", + "(512, 512)" + ], + [ + "condition/example/t2i/multi_resolution/car.jpg", + "a sport car", + "(448, 768)" + ] +] +def create_demo(process): + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + image = gr.Image() + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button("Run") + with gr.Accordion("Advanced options", open=False): + cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1) + resolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16) + top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K') + top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P") + temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature') + seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + with gr.Column(): + result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down") + gr.Examples( + examples=examples, + inputs=[ + image, + prompt, + resolution, + ], + outputs=result, + fn=process, + ) + inputs = [ + image, + prompt, + cfg_scale, + temperature, + top_k, + top_p, + seed, + ] + prompt.submit( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=process, + inputs=inputs, + outputs=result, + api_name=False, + ) + run_button.click( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=process, + inputs=inputs, + outputs=result, + api_name="canny", + ) + return demo +if __name__ == "__main__": + from model import Model + model = Model() + demo = create_demo(model.process_depth) + demo.queue().launch( + share=False, + server_name="0.0.0.0" + ) \ No newline at end of file diff --git a/autoregressive/models/README.md b/autoregressive/models/README.md new file mode 100644 
index 0000000000000000000000000000000000000000..08a335a767cf323f41b85d2933b5892cbbfcb239 --- /dev/null +++ b/autoregressive/models/README.md @@ -0,0 +1,6 @@ +Download the vit weight first + +ViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \ +Dinov2-small: https://huggingface.co/facebook/dinov2-small + +Put them here \ No newline at end of file diff --git a/autoregressive/models/dinov2_adapter.py b/autoregressive/models/dinov2_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..81447d400899da754a3f58a6be464c22aa06bc11 --- /dev/null +++ b/autoregressive/models/dinov2_adapter.py @@ -0,0 +1,36 @@ +from transformers import AutoImageProcessor, AutoModel +from PIL import Image +import requests +import torch +import torch.nn as nn + + +class Dinov2_Adapter(nn.Module): + def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'): + super(Dinov2_Adapter, self).__init__() + print(f"Choose adapter size: {adapter_size}") + print(f"condition type: {condition_type}") + self.model = AutoModel.from_pretrained(f'autoregressive/models/dinov2-{adapter_size}') + self.condition_type = condition_type + + def to_patch14(self, input): + H, W = input.shape[2:] + new_H = (H // 16) * 14 + new_W = (W // 16) * 14 + if self.condition_type in ['canny', 'seg']: + output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='nearest')#, align_corners=True) canny, seg + else: + output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='bicubic', align_corners=True) # depth, lineart, hed + return output + + def forward(self, x): + x = self.to_patch14(x) + x = self.model(x) + return x.last_hidden_state[:, 1:] + + +if __name__ == '__main__': + model = Dinov2_Adapter().cuda() + inputs = torch.randn(4,3,512,512).cuda() + outputs = model(inputs) + print(outputs.shape) \ No newline at end of file diff --git a/autoregressive/models/generate.py b/autoregressive/models/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..59179e7af246b9beec1a6164fca3ab2d4574e7a3 --- /dev/null +++ b/autoregressive/models/generate.py @@ -0,0 +1,204 @@ +# Modified from: +# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch._dynamo.config +import torch._inductor.config +import copy +import time +# torch._inductor.config.coordinate_descent_tuning = True +# torch._inductor.config.triton.unique_kernel_names = True +# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + +### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html +def top_k_top_p_filtering( + logits, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + # import pdb;pdb.set_trace() + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True): + logits = logits[:, -1, :] / max(temperature, 1e-5) + if top_k > 0 or top_p < 1.0: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = F.softmax(logits, dim=-1) + # values, indices = torch.max(probs, dim=1, keepdim=True) + # mask = (probs == values).float() + # probs = probs * (1 - mask) + # values, indices = torch.max(probs, dim=1, keepdim=True) + # mask = (probs == values).float() + # probs = probs * (1 - mask) + if sample_logits: + idx = torch.multinomial(probs, num_samples=1) + else: + _, idx = torch.topk(probs, k=1, dim=-1) + return idx, probs + + +def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs): + logits = logits / max(temperature, 1e-5) + if top_k > 0 or top_p < 1.0: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs): + if cfg_scale > 1.0: + logits, _ = model(None, cond_idx, input_pos, condition=condition) + logits_combined = logits + cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0) + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + else: + logits, _ = model(None, cond_idx, input_pos, condition=condition) + + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs): + assert input_pos.shape[-1] == 1 + if cfg_scale > 1.0: + x_combined = torch.cat([x, x]) + logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos, condition=condition) + logits_combined = logits + cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0) + if cfg_flag: + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + else: + logits = cond_logits + else: + 
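+            # cfg_scale <= 1.0: single conditional forward pass, no guidance mixing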
logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, + cfg_scale: float, cfg_interval: int, condition: torch.Tensor, + **sampling_kwargs): + new_tokens, new_probs = [], [] + cfg_flag = True + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + if cfg_interval > -1 and i > cfg_interval: + cfg_flag = False + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, cfg_scale, cfg_flag, condition=condition, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(-1, 1) + + return new_tokens, new_probs + + +@torch.no_grad() +def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs): + if condition is not None: + condition = model.adapter(condition) + condition = model.adapter_mlp(condition) + if model.model_type == 'c2i': + if cfg_scale > 1.0: + cond_null = torch.ones_like(cond) * model.num_classes + cond_combined = torch.cat([cond, cond_null]) + if condition is not None: + condition_null = torch.zeros_like(condition) + condition_combined = torch.cat((condition, condition_null), dim=0) + else: + condition_combined = None + else: + cond_combined = cond + if condition is not None: + condition_combined = condition + else: + condition_combined = None + T = 1+condition_token_nums + elif model.model_type == 't2i': + if cfg_scale > 1.0: + cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding + cond_combined = torch.cat([cond, cond_null]) + + if condition is not None: + condition_null = torch.zeros_like(condition) + condition_combined = torch.cat((condition, condition_null), dim=0) + else: + condition_combined = None + else: + cond_combined = cond + if condition is not None: + condition_combined = condition + else: + condition_combined = None + T = cond.shape[1] + else: + raise Exception("please check model type") + + T_new = T + max_new_tokens + max_seq_length = T_new + max_batch_size = cond.shape[0] + + device = cond.device + with torch.device(device): + max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size + model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype) + + if emb_masks is not None: + assert emb_masks.shape[0] == max_batch_size + assert emb_masks.shape[-1] == T + if cfg_scale > 1.0: + model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1) + else: + model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1) + + eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device) + model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix + + # create an empty tensor of the expected final shape and fill in the current tokens + seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device) + input_pos = torch.arange(0, T, device=device) + next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs) + seq[:, T:T+1] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) 
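+    # decode the remaining max_new_tokens - 1 tokens one at a time, reusing the KV cache filled during prefill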
+ generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs) + seq[:, T+1:] = torch.cat(generated_tokens, dim=1) + return seq[:, T:] diff --git a/autoregressive/models/gpt_t2i.py b/autoregressive/models/gpt_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..be883e91c03079c9a06a669d5f6ad8380de93c7e --- /dev/null +++ b/autoregressive/models/gpt_t2i.py @@ -0,0 +1,561 @@ +# Modified from: +# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py +# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py +# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py +# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py +# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py +from dataclasses import dataclass +from typing import Optional, List + + +import torch +import torch.nn as nn +from torch.nn import functional as F +from utils.drop_path import DropPath +# from autoregressive.models.vit_adapter import ViT_Adapter +from autoregressive.models.dinov2_adapter import Dinov2_Adapter + + +def get_causal_mask(seq_length): + mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool) + mask = mask.masked_fill(mask, float('-inf')) + mask = mask.masked_fill(~mask, float(0.0)) + return mask + +def find_multiple(n: int, k: int): + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layer: int = 32 + n_head: int = 32 + n_kv_head: Optional[int] = None + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + rope_base: float = 10000 + norm_eps: float = 1e-5 + initializer_range: float = 0.02 + + token_dropout_p: float = 0.1 + attn_dropout_p: float = 0.0 + resid_dropout_p: float = 0.1 + ffn_dropout_p: float = 0.1 + drop_path_rate: float = 0.0 + + num_classes: int = 1000 + caption_dim: int = 2048 + class_dropout_prob: float = 0.1 + model_type: str = 'c2i' + + vocab_size: int = 16384 + cls_token_num: int = 1 + block_size: int = 256 + max_batch_size: int = 32 + max_seq_len: int = 2048 + adapter_size: str = 'small' + condition_type: str = 'canny' + + + +################################################################################# +# Embedding Layers for Class Labels # +################################################################################# +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
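+        Returns the labels (with dropped entries replaced by the unconditional class id) together with the boolean drop mask.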
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels, drop_ids + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels,drop_ids = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels).unsqueeze(1) + if (train and use_dropout) or (force_drop_ids is not None): + return embeddings,drop_ids + else: + return embeddings + + +class ConditionEmbedder(nn.Module): + """ + Embeds Condition into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384): + super().__init__() + self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size) + self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None, drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + if drop_ids is None: + drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + + caption = torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption) + return caption + + def forward(self, caption, train, force_drop_ids=None, drop_ids=None): + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids, drop_ids) + embeddings = self.cap_proj(caption) + return embeddings + +################################################################################# +# Embedding Layers for Text Feature # +################################################################################# +class CaptionEmbedder(nn.Module): + """ + Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120): + super().__init__() + self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size) + self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption) + return caption, drop_ids + + def forward(self, caption, train, force_drop_ids=None): + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption, drop_ids = self.token_drop(caption, force_drop_ids) + embeddings = self.cap_proj(caption) + if (train and use_dropout) or (force_drop_ids is not None): + return embeddings,drop_ids + else: + return embeddings + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features, out_features): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=False) + self.act = nn.GELU(approximate='tanh') + self.fc2 = nn.Linear(hidden_features, out_features, bias=False) + + nn.init.zeros_(self.fc1.weight) + nn.init.zeros_(self.fc2.weight) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################# +# GPT Model # +################################################################################# +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + hidden_dim = 4 * config.dim + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if config.ffn_dim_multiplier is not None: + hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) + hidden_dim = find_multiple(hidden_dim, config.multiple_of) + + self.w1 = nn.Linear(config.dim, hidden_dim, bias=False) + self.w3 = nn.Linear(config.dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, config.dim, bias=False) + self.ffn_dropout = nn.Dropout(config.ffn_dropout_p) + + def forward(self, x): + return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) + + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype): + super().__init__() + cache_shape = (max_batch_size, n_head, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + self.dim = config.dim + self.head_dim = config.dim // config.n_head + self.n_head = config.n_head + self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head + total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim + + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_kv_dim, 
bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + # regularization + self.attn_dropout_p = config.attn_dropout_p + self.resid_dropout = nn.Dropout(config.resid_dropout_p) + + def forward( + self, x: torch.Tensor, freqs_cis: torch.Tensor = None, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None + ): + bsz, seqlen, _ = x.shape + kv_size = self.n_kv_head * self.head_dim + xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) + + xq = apply_rotary_emb(xq, freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + + if self.kv_cache is not None: + keys, values = self.kv_cache.update(input_pos, xk, xv) + else: + keys, values = xk, xv + keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1) + values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1) + + output = F.scaled_dot_product_attention( + xq, keys, values, + attn_mask=mask, + is_causal=True if mask is None else False, # is_causal=False is for KV cache + dropout_p=self.attn_dropout_p if self.training else 0) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + output = self.resid_dropout(self.wo(output)) + return output + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs, drop_path: float): + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward( + self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None): + h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)) + out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) + return out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.n_layer = config.n_layer + self.block_size = config.block_size + self.num_classes = config.num_classes + self.model_type = config.model_type + self.cls_token_num = config.cls_token_num + self.layer_internal = config.n_layer // 3 + # self.adapter = Adapter(output_dim=768) + # self.adapter = ViT_Adapter() + # self.adapter = DeiT_Adapter() + self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type) + # self.adapter = EVA_Adapter() + if config.adapter_size == "small": + self.adapter_mlp = MLP(384, config.dim, config.dim) + elif config.adapter_size == 'base': + self.adapter_mlp = MLP(768, config.dim, config.dim) + + if self.model_type == 'c2i': + self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob) + elif self.model_type == 't2i': + self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob) + else: + raise Exception("please check model type") + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.tok_dropout = nn.Dropout(config.token_dropout_p) + + self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size) + self.condition_layers = torch.nn.ModuleList() + for layer_id in range(3): + self.condition_layers.append(MLP(config.dim,config.dim,config.dim)) + + # transformer blocks + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)] + self.layers = torch.nn.ModuleList() + for layer_id in range(config.n_layer): + self.layers.append(TransformerBlock(config, dpr[layer_id])) + + # output layer + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + # 2d rotary pos embedding + grid_size = int(self.block_size ** 0.5) + assert grid_size * grid_size == self.block_size + self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num) + + # KVCache + self.max_batch_size = -1 + self.max_seq_length = -1 + + self.initialize_weights() + self.condition_token = None + self.mask = get_causal_mask(256) + self.global_token = None + + + def initialize_weights(self): + # Initialize nn.Linear and nn.Embedding + self.apply(self._init_weights) + + # Zero-out output layers: + nn.init.constant_(self.output.weight, 0) + + + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + + + def setup_caches(self, max_batch_size, max_seq_length, dtype): + # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + # return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) # + 
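+        # the cache length is rounded up to a multiple of 8 before the KV caches are allocated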
self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype) + + causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1) + grid_size = int(self.config.block_size ** 0.5) + assert grid_size * grid_size == self.block_size + self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num) + + + + def forward( + self, + idx: torch.Tensor, + cond_idx: torch.Tensor, # cond_idx_or_embed + input_pos: Optional[torch.Tensor] = None, + targets: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + valid: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None + ): + if idx is not None and cond_idx is not None: # training or naive inference + cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training) + cond_embeddings = cond_embeddings[:,:self.cls_token_num] + token_embeddings = self.tok_embeddings(idx) + if condition is not None: + condition_embeddings = self.adapter(condition) + condition_embeddings = self.adapter_mlp(condition_embeddings) + self.condition_token = self.condition_mlp(condition_embeddings,train=self.training, drop_ids=drop_ids) + token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1) + + h = self.tok_dropout(token_embeddings) + self.freqs_cis = self.freqs_cis.to(h.device) + else: + if cond_idx is not None: # prefill in inference + token_embeddings = self.cls_embedding(cond_idx, train=self.training) + token_embeddings = token_embeddings[:,:self.cls_token_num] + if condition is not None: + condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16),train=self.training) + self.condition_token = condition_embeddings + + else: # decode_n_tokens(kv cache) in inference + token_embeddings = self.tok_embeddings(idx) + bs = token_embeddings.shape[0] + mask = self.causal_mask[:bs, None, input_pos] + h = self.tok_dropout(token_embeddings) + self.freqs_cis = self.freqs_cis + + if self.training: + freqs_cis = self.freqs_cis[:token_embeddings.shape[1]] + else: + freqs_cis = self.freqs_cis[input_pos] + # transformer blocks + for i, layer in enumerate(self.layers): + if i%self.layer_internal == 0: + if self.training: + h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token) + else: + if len(input_pos)>1: + h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1]) + else: + h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1]) + h = layer(h, freqs_cis, input_pos, mask) + # output layers + h = self.norm(h) + logits = self.output(h).float() + + if self.training: + logits = logits[:, self.cls_token_num - 1:].contiguous() + # if we are given some desired targets also calculate the loss + loss = None + if valid is not None: + loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1) + loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1) + elif targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + + return logits, loss + + + def get_fsdp_wrap_module_list(self) -> 
List[nn.Module]: + return list(self.layers) + + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120): + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2) + cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2) + return cond_cache + + +def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120): + # split the dimension into half, one for x and one for y + half_dim = n_elem // 2 + freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim)) + t = torch.arange(grid_size, device=freqs.device) + freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2) + freqs_grid = torch.concat([ + freqs[:, None, :].expand(-1, grid_size, -1), + freqs[None, :, :].expand(grid_size, -1, -1), + ], dim=-1) # (grid_size, grid_size, head_dim // 2) + cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2) + cache = cache_grid.flatten(0, 1) + cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2) + return cond_cache + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor): + # x: (bs, seq_len, n_head, head_dim) + # freqs_cis (seq_len, head_dim // 2, 2) + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2) + x_out2 = torch.stack([ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], dim=-1) + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + + +################################################################################# +# GPT Configs # +################################################################################# +### text-conditional +def GPT_7B(**kwargs): + return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B + +def GPT_3B(**kwargs): + return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B + +def GPT_1B(**kwargs): + return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B + +### class-conditional +def GPT_XXXL(**kwargs): + return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B + +def GPT_XXL(**kwargs): + return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B + +def GPT_XL(**kwargs): + return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M + +def GPT_L(**kwargs): + return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M + +def GPT_B(**kwargs): + return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M + + +GPT_models = { + 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 
'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL, + 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B, +} \ No newline at end of file diff --git a/autoregressive/sample/sample_c2i.py b/autoregressive/sample/sample_c2i.py new file mode 100644 index 0000000000000000000000000000000000000000..5174efa87138e45e31fcb552e85e9f1f6e971fae --- /dev/null +++ b/autoregressive/sample/sample_c2i.py @@ -0,0 +1,151 @@ +# Modified from: +# DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') +setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) +setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) +from torchvision.utils import save_image +import os +import sys +current_directory = os.getcwd() +sys.path.append(current_directory) + +from PIL import Image +import time +import argparse +from tokenizer.tokenizer_image.vq_model import VQ_models +from autoregressive.models.gpt import GPT_models +from autoregressive.models.generate import generate +from functools import partial +import torch.nn.functional as F +import numpy as np +import cv2 + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.set_grad_enabled(False) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + print(f"image tokenizer is loaded") + + # create and load gpt model + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + latent_size = args.image_size // args.downsample_size + gpt_model = GPT_models[args.gpt_model]( + vocab_size=args.codebook_size, + block_size=latent_size ** 2, + num_classes=args.num_classes, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + condition_token_num=args.condition_token_nums, + image_size=args.image_size + ).to(device=device, dtype=precision) + + _, file_extension = os.path.splitext(args.gpt_ckpt) + if file_extension.lower() == '.safetensors': + from safetensors.torch import load_file + model_weight = load_file(args.gpt_ckpt) + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + else: + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") + if "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + del checkpoint + print(f"gpt model is loaded") + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no need to compile model in demo") + + condition_null = None + if args.condition_type == 'canny': + sample_list = [650, 2312, 15000, 48850] # canny + elif args.condition_type == 'depth': + sample_list = [101, 4351, 10601, 48901] + + class_labels = 
[np.load(f"condition/example/c2i/{args.condition_type}/{i}.npy")[0] for i in sample_list] + condition_imgs = [np.array(Image.open((f"condition/example/c2i/{args.condition_type}/{i}.png")))[None,None,...] for i in sample_list] + condition_imgs = torch.from_numpy(np.concatenate(condition_imgs, axis=0)).to(device).to(torch.float32)/255 + condition_imgs = 2*(condition_imgs-0.5) + print(condition_imgs.shape) + c_indices = torch.tensor(class_labels, device=device) + qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size] + t1 = time.time() + + index_sample = generate( + gpt_model, c_indices, latent_size ** 2, condition=condition_imgs.repeat(1,3,1,1).to(precision), condition_null=condition_null, condition_token_nums=args.condition_token_nums, + cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval, + temperature=args.temperature, top_k=args.top_k, + top_p=args.top_p, sample_logits=True, + ) + + sampling_time = time.time() - t1 + print(f"gpt sampling takes about {sampling_time:.2f} seconds.") + + t2 = time.time() + samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] + decoder_time = time.time() - t2 + print(f"decoder takes about {decoder_time:.2f} seconds.") + # Save and display images: + condition_imgs = condition_imgs.repeat(1,3,1,1) + samples = torch.cat((condition_imgs[:4], samples[:4]),dim=0) + save_image(samples, f"sample/example/sample_{args.gpt_type}_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1)) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B") + parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") + parser.add_argument("--from-fsdp", action='store_true') + parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=False) + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=4.0) + parser.add_argument("--cfg-interval", type=float, default=-1) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + parser.add_argument("--condition-token-nums", type=int, default=0) + parser.add_argument("--condition-type", type=str, default='canny', choices=['canny', 'depth']) + args = parser.parse_args() + main(args) \ No newline at end of file 
diff --git a/autoregressive/sample/sample_c2i_ddp.py b/autoregressive/sample/sample_c2i_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..f75c0868486a1365334a80c0c6e3d35e191d6d3e --- /dev/null +++ b/autoregressive/sample/sample_c2i_ddp.py @@ -0,0 +1,188 @@ +# Modified from: +# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.nn.functional as F +import torch.distributed as dist + +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import math +import argparse + +from tokenizer.tokenizer_image.vq_model import VQ_models +from autoregressive.models.gpt import GPT_models +from autoregressive.models.generate import generate + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + + # create and load gpt model + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + latent_size = args.image_size // args.downsample_size + gpt_model = GPT_models[args.gpt_model]( + vocab_size=args.codebook_size, + block_size=latent_size ** 2, + num_classes=args.num_classes, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + ).to(device=device, dtype=precision) + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") + if args.from_fsdp: # fsdp + model_weight = checkpoint + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight, maybe add --from-fsdp to run command") + # if 'freqs_cis' in model_weight: + # model_weight.pop('freqs_cis') + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + del checkpoint + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no model compile") + + # Create folder to save samples: + model_string_name = 
args.gpt_model.replace("/", "-") + if args.from_fsdp: + ckpt_string_name = args.gpt_ckpt.split('/')[-2] + else: + ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "") + folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-{args.vq_model}-" \ + f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \ + f"cfg-{args.cfg_scale}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: + total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) + if rank == 0: + print(f"Total number of images that will be sampled: {total_samples}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + pbar = range(iterations) + pbar = tqdm(pbar) if rank == 0 else pbar + total = 0 + for _ in pbar: + # Sample inputs: + c_indices = torch.randint(0, args.num_classes, (n,), device=device) + qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size] + + index_sample = generate( + gpt_model, c_indices, latent_size ** 2, + cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval, + temperature=args.temperature, top_k=args.top_k, + top_p=args.top_p, sample_logits=True, + ) + + samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] + if args.image_size_eval != args.image_size: + samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + total += global_batch_size + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B") + parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") + parser.add_argument("--from-fsdp", action='store_true') + parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=True) + parser.add_argument("--vq-model", 
type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384) + parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.5) + parser.add_argument("--cfg-interval", type=float, default=-1) + parser.add_argument("--sample-dir", type=str, default="samples") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--num-fid-samples", type=int, default=5000) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--top-k", type=int, default=0,help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/autoregressive/sample/sample_t2i.py b/autoregressive/sample/sample_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5905b9c2ab31a4ef9bbbf7278b6b7971360428 --- /dev/null +++ b/autoregressive/sample/sample_t2i.py @@ -0,0 +1,215 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') +setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +from torchvision.utils import save_image + +import os +import sys +current_directory = os.getcwd() +sys.path.append(current_directory) +import time +import argparse +from tokenizer.tokenizer_image.vq_model import VQ_models +from language.t5 import T5Embedder +from autoregressive.models.gpt import GPT_models +from autoregressive.models.gpt_t2i import GPT_models +from autoregressive.models.generate import generate +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from dataset.t2i_control import build_t2i_control_code +from accelerate import Accelerator +from dataset.build import build_dataset +from pathlib import Path +from accelerate.utils import ProjectConfiguration, set_seed +import torch.nn.functional as F +from condition.canny import CannyDetector +from condition.hed import HEDdetector +import numpy as np +from PIL import Image +from condition.lineart import LineArt +import cv2 +from transformers import DPTImageProcessor, DPTForDepthEstimation +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + 
vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + print(f"image tokenizer is loaded") + + # create and load gpt model + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + latent_size = args.image_size // args.downsample_size + gpt_model = GPT_models[args.gpt_model]( + block_size=latent_size ** 2, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + condition_type=args.condition_type, + ).to(device=device, dtype=precision) + + _, file_extension = os.path.splitext(args.gpt_ckpt) + if file_extension.lower() == '.safetensors': + from safetensors.torch import load_file + model_weight = load_file(args.gpt_ckpt) + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + else: + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") + if "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + del checkpoint + print(f"gpt model is loaded") + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no need to compile model in demo") + + assert os.path.exists(args.t5_path) + t5_model = T5Embedder( + device=device, + local_cache=True, + cache_dir=args.t5_path, + dir_or_name=args.t5_model_type, + torch_dtype=precision, + model_max_length=args.t5_feature_max_len, + ) + + + if args.condition_type == 'canny': + get_control = CannyDetector() + elif args.condition_type == 'hed': + get_control = HEDdetector().to(device).eval() + elif args.condition_type == 'lineart': + get_control = LineArt() + get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu'))) + get_control.to(device) + elif args.condition_type == 'depth': + processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large") + model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device) + with torch.no_grad(): + + condition_path = args.condition_path + if args.condition_type == 'seg': + condition_img = torch.from_numpy(np.array(Image.open(condition_path))) + condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1) + elif args.condition_type == 'canny': + condition_img = get_control(np.array(Image.open(condition_path))) + condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1) + elif args.condition_type == 'hed': + condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device)) + condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1) + elif args.condition_type == 'lineart': + condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device).float()) + condition_img = condition_img.repeat(2,3,1,1) * 255 + elif args.condition_type == 'depth': + images = Image.open(condition_path) + inputs = processor(images=images, return_tensors="pt", size=(512,512)).to(device) + outputs = model(**inputs) + condition_img = outputs.predicted_depth + condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1) + condition_img = (condition_img * 255 / condition_img.max()) + condition_img = condition_img.to(device) 
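+        # NOTE (comment added for clarity): at this point condition_img has been batched to 2 samples,
+        # matching the duplicated prompt pair built below; the next statement rescales its [0, 255]
+        # values to [-1, 1] before it is handed to generate() as the `condition` input.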
+ condition_img = 2*(condition_img/255 - 0.5) + prompts = [args.prompt if args.prompt is not None else "a high-quality image"] + prompts = prompts * 2 + caption_embs, emb_masks = t5_model.get_text_embeddings(prompts) + + if not args.no_left_padding: + print(f"processing left-padding...") + # a naive way to implement left-padding + new_emb_masks = torch.flip(emb_masks, dims=[-1]) + new_caption_embs = [] + for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)): + valid_num = int(emb_mask.sum().item()) + print(f' prompt {idx} token len: {valid_num}') + new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]]) + new_caption_embs.append(new_caption_emb) + new_caption_embs = torch.stack(new_caption_embs) + else: + new_caption_embs, new_emb_masks = caption_embs, emb_masks + c_indices = new_caption_embs * new_emb_masks[:,:, None] + c_emb_masks = new_emb_masks + qzshape = [len(c_indices), args.codebook_embed_dim, args.image_H//args.downsample_size, args.image_W//args.downsample_size] + t1 = time.time() + index_sample = generate( + gpt_model, c_indices, (args.image_H//args.downsample_size)*(args.image_W//args.downsample_size),#latent_size ** 2, + c_emb_masks, condition=condition_img.to(precision), + cfg_scale=args.cfg_scale, + temperature=args.temperature, top_k=args.top_k, + top_p=args.top_p, sample_logits=True, + ) + sampling_time = time.time() - t1 + print(f"Full sampling takes about {sampling_time:.2f} seconds.") + + t2 = time.time() + print(index_sample.shape) + samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] + decoder_time = time.time() - t2 + print(f"decoder takes about {decoder_time:.2f} seconds.") + + samples = torch.cat((condition_img[0:1], samples), dim=0) + save_image(samples, f"sample/example/sample_t2i_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1)) + print(f"image is saved to sample/example/sample_t2i_{args.condition_type}.png") + print(prompts) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt') + parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') + parser.add_argument("--t5-feature-max-len", type=int, default=120) + parser.add_argument("--t5-feature-dim", type=int, default=2048) + parser.add_argument("--no-left-padding", action='store_true', default=False) + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL") + parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image") + parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=False) + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768) + parser.add_argument("--image-H", 
type=int, default=512) + parser.add_argument("--image-W", type=int, default=512) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--cfg-scale", type=float, default=4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + + parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny") + parser.add_argument("--prompt", type=str, default='a high-quality image') + parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png') + args = parser.parse_args() + main(args) diff --git a/autoregressive/sample/sample_t2i_MR.py b/autoregressive/sample/sample_t2i_MR.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e312ce64b4604ded80231d1d252a040b7d9e77 --- /dev/null +++ b/autoregressive/sample/sample_t2i_MR.py @@ -0,0 +1,237 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') +setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +from torchvision.utils import save_image + +import os +import sys +current_directory = os.getcwd() +sys.path.append(current_directory) +import time +import argparse +from tokenizer.tokenizer_image.vq_model import VQ_models +from language.t5 import T5Embedder +from autoregressive.models.gpt_t2i import GPT_models +from autoregressive.models.generate import generate +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from dataset.t2i_control import build_t2i_control_code +from accelerate import Accelerator +from dataset.build import build_dataset +from pathlib import Path +from accelerate.utils import ProjectConfiguration, set_seed +import torch.nn.functional as F +from condition.canny import CannyDetector +from condition.hed import HEDdetector +import numpy as np +from PIL import Image +from condition.lineart import LineArt +import cv2 +from transformers import DPTImageProcessor, DPTForDepthEstimation +from condition.midas.depth import MidasDetector + + +def resize_image_to_16_multiple(image_path, condition_type='seg'): + image = Image.open(image_path) + width, height = image.size + + if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32 + new_width = (width + 31) // 32 * 32 + new_height = (height + 31) // 32 * 32 + else: + new_width = (width + 15) // 16 * 16 + new_height = (height + 15) // 16 * 16 + + resized_image = image.resize((new_width, new_height)) + return resized_image + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + 
checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + print(f"image tokenizer is loaded") + + # create and load gpt model + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + latent_size = args.image_size // args.downsample_size + gpt_model = GPT_models[args.gpt_model]( + block_size=latent_size ** 2, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + condition_type=args.condition_type, + ).to(device=device, dtype=precision) + + _, file_extension = os.path.splitext(args.gpt_ckpt) + if file_extension.lower() == '.safetensors': + from safetensors.torch import load_file + model_weight = load_file(args.gpt_ckpt) + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + else: + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") + if "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + del checkpoint + print(f"gpt model is loaded") + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no need to compile model in demo") + + assert os.path.exists(args.t5_path) + t5_model = T5Embedder( + device=device, + local_cache=True, + cache_dir=args.t5_path, + dir_or_name=args.t5_model_type, + torch_dtype=precision, + model_max_length=args.t5_feature_max_len, + ) + + + if args.condition_type == 'canny': + get_control = CannyDetector() + elif args.condition_type == 'hed': + get_control = HEDdetector().to(device).eval() + elif args.condition_type == 'lineart': + get_control = LineArt() + get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu'))) + get_control.to(device) + elif args.condition_type == 'depth': + processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large") + model_large = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device) + model = MidasDetector(device=device) + with torch.no_grad(): + + condition_img = resize_image_to_16_multiple(args.condition_path, args.condition_type) + W, H = condition_img.size + print(H,W) + if args.condition_type == 'seg': + condition_img = torch.from_numpy(np.array(condition_img)) + condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1) + elif args.condition_type == 'canny': + condition_img = get_control(np.array(condition_img)) + condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1) + elif args.condition_type == 'hed': + condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device)) + condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1) + elif args.condition_type == 'lineart': + condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device).float()) + condition_img = condition_img.repeat(2,3,1,1) * 255 + elif args.condition_type == 'depth': + images = condition_img + if H == W: + inputs = processor(images=images, return_tensors="pt", size=(H,W)).to(device) + outputs = model_large(**inputs) + condition_img = outputs.predicted_depth + condition_img = 
(condition_img * 255 / condition_img.max()) + else: + condition_img = torch.from_numpy(model(torch.from_numpy(np.array(condition_img)).to(device))).unsqueeze(0) + condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1) + condition_img = condition_img.to(device) + condition_img = 2*(condition_img/255 - 0.5) + prompts = [args.prompt if args.prompt is not None else "a high-quality image"] + prompts = prompts * 2 + caption_embs, emb_masks = t5_model.get_text_embeddings(prompts) + + if not args.no_left_padding: + print(f"processing left-padding...") + # a naive way to implement left-padding + new_emb_masks = torch.flip(emb_masks, dims=[-1]) + new_caption_embs = [] + for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)): + valid_num = int(emb_mask.sum().item()) + print(f' prompt {idx} token len: {valid_num}') + new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]]) + new_caption_embs.append(new_caption_emb) + new_caption_embs = torch.stack(new_caption_embs) + else: + new_caption_embs, new_emb_masks = caption_embs, emb_masks + c_indices = new_caption_embs * new_emb_masks[:,:, None] + c_emb_masks = new_emb_masks + qzshape = [len(c_indices), args.codebook_embed_dim, H//args.downsample_size, W//args.downsample_size] + t1 = time.time() + index_sample = generate( + gpt_model, c_indices, (H//args.downsample_size)*(W//args.downsample_size),#latent_size ** 2, + c_emb_masks, condition=condition_img.to(precision), + cfg_scale=args.cfg_scale, + temperature=args.temperature, top_k=args.top_k, + top_p=args.top_p, sample_logits=True, + ) + sampling_time = time.time() - t1 + print(f"Full sampling takes about {sampling_time:.2f} seconds.") + + t2 = time.time() + print(index_sample.shape) + samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] + decoder_time = time.time() - t2 + print(f"decoder takes about {decoder_time:.2f} seconds.") + + samples = torch.cat((condition_img[0:1], samples), dim=0) + save_image(samples, f"sample/example/sample_t2i_MR_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1)) + print(f"image is saved to sample/example/sample_t2i_MR_{args.condition_type}.png") + print(prompts) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt') + parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') + parser.add_argument("--t5-feature-max-len", type=int, default=120) + parser.add_argument("--t5-feature-dim", type=int, default=2048) + parser.add_argument("--no-left-padding", action='store_true', default=False) + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL") + parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image") + parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=False) + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", 
type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768) + parser.add_argument("--image-H", type=int, default=512) + parser.add_argument("--image-W", type=int, default=512) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--cfg-scale", type=float, default=4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + + parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny") + parser.add_argument("--prompt", type=str, default='a high-quality image') + parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png') + args = parser.parse_args() + main(args) diff --git a/autoregressive/sample/sample_t2i_ddp.py b/autoregressive/sample/sample_t2i_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..fb744b79e1cf54de9089d1117dc4d10ed128bec4 --- /dev/null +++ b/autoregressive/sample/sample_t2i_ddp.py @@ -0,0 +1,229 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision('high') +setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed +import torch.nn.functional as F +import torch.distributed as dist + +import os +import math +import json +import argparse +import pandas as pd +from tqdm import tqdm +from PIL import Image + +from tokenizer.tokenizer_image.vq_model import VQ_models +from language.t5 import T5Embedder +from autoregressive.models.gpt import GPT_models +from autoregressive.models.generate import generate +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + print(f"image tokenizer is loaded") + + # create and load gpt model + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + latent_size = args.image_size // args.downsample_size + gpt_model = GPT_models[args.gpt_model]( + block_size=latent_size ** 2, + cls_token_num=args.cls_token_num, + model_type=args.gpt_type, + ).to(device=device, dtype=precision) + + checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") + + if "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "module" in checkpoint: # deepspeed + model_weight = checkpoint["module"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + del checkpoint + print(f"gpt model is loaded") + + if args.compile: + print(f"compiling the model...") + gpt_model = torch.compile( + gpt_model, + mode="reduce-overhead", + fullgraph=True + ) # requires PyTorch 2.0 (optional) + else: + print(f"no need to compile model in demo") + + assert os.path.exists(args.t5_path) + t5_model = T5Embedder( + device=device, + local_cache=True, + cache_dir=args.t5_path, + dir_or_name=args.t5_model_type, + torch_dtype=precision, + model_max_length=args.t5_feature_max_len, + ) + print(f"t5 model is loaded") + + # Create folder to save samples: + model_string_name = args.gpt_model.replace("/", "-") + ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "") + prompt_name = args.prompt_csv.split('/')[-1].split('.')[0].lower() + folder_name = f"{model_string_name}-{ckpt_string_name}-{prompt_name}-size-{args.image_size}-size-{args.image_size}-{args.vq_model}-" \ + f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \ + f"cfg-{args.cfg_scale}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(f"{sample_folder_dir}/images", exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}/images") + dist.barrier() + + df = pd.read_csv(args.prompt_csv, delimiter='\t') + prompt_list = df['Prompt'].tolist() + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + num_fid_samples = min(args.num_fid_samples, len(prompt_list)) + # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: + total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size) + if rank == 0: + print(f"Total number of images that will be sampled: {total_samples}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + 
samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + pbar = range(iterations) + pbar = tqdm(pbar) if rank == 0 else pbar + total = 0 + for _ in pbar: + # Select text prompt + prompt_batch = [] + for i in range(n): + index = i * dist.get_world_size() + rank + total + prompt_batch.append(prompt_list[index] if index < len(prompt_list) else "a cute dog") + + # Sample inputs: + caption_embs, emb_masks = t5_model.get_text_embeddings(prompt_batch) + + if not args.no_left_padding: + new_emb_masks = torch.flip(emb_masks, dims=[-1]) + new_caption_embs = [] + for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)): + valid_num = int(emb_mask.sum().item()) + # prompt_cur = prompt_batch[idx] + # print(f' prompt {idx} token len: {valid_num} : {prompt_cur}') + new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]]) + new_caption_embs.append(new_caption_emb) + new_caption_embs = torch.stack(new_caption_embs) + + else: + new_caption_embs, new_emb_masks = caption_embs, emb_masks + + c_indices = new_caption_embs * new_emb_masks[:,:, None] + c_emb_masks = new_emb_masks + + qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size] + index_sample = generate( + gpt_model, c_indices, latent_size ** 2, + c_emb_masks, + cfg_scale=args.cfg_scale, + temperature=args.temperature, top_k=args.top_k, + top_p=args.top_p, sample_logits=True, + ) + + samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/images/{index:06d}.png") + total += global_batch_size + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + # Save infer result in a jsonl file + json_items = [] + for idx, prompt in enumerate(prompt_list): + image_path = os.path.join(sample_folder_dir, "images", f"{idx:06d}.png") + json_items.append({"text": prompt, "image_path": image_path}) + res_jsonl_path = os.path.join(sample_folder_dir, "result.jsonl") + print(f"Save jsonl to {res_jsonl_path}...") + with open(res_jsonl_path, "w") as f: + for item in json_items: + f.write(json.dumps(item) + "\n") + + # Save captions to txt + caption_path = os.path.join(sample_folder_dir, "captions.txt") + print(f"Save captions to {caption_path}...") + with open(caption_path, "w") as f: + for item in prompt_list: + f.write(f"{item}\n") + print("Done.") + + dist.barrier() + dist.destroy_process_group() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt-csv", type=str, default='evaluations/t2i/PartiPrompts.tsv') + parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt') + parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') + parser.add_argument("--t5-feature-max-len", type=int, default=120) + parser.add_argument("--t5-feature-dim", type=int, default=2048) + parser.add_argument("--no-left-padding", action='store_true', default=False) + parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL") + 
parser.add_argument("--gpt-ckpt", type=str, default=None) + parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image") + parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input") + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--compile", action='store_true', default=False) + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512) + parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=7.5) + parser.add_argument("--sample-dir", type=str, default="samples_parti", help="samples_coco or samples_parti") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--num-fid-samples", type=int, default=30000) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") + parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") + args = parser.parse_args() + main(args) diff --git a/checkpoints/vq_ds16_t2i.pt b/checkpoints/vq_ds16_t2i.pt new file mode 100644 index 0000000000000000000000000000000000000000..41cd10c4636fb4c90906b1871f5a77b829e79161 --- /dev/null +++ b/checkpoints/vq_ds16_t2i.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e21fc1318e2e9ee641a07bdad0e20675e9ec35e6e3eb911d58b5d7a2cd8d4cb +size 287920306 diff --git a/condition/README.md b/condition/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf146b4ed7214413b51cb24bda8857751caeb514 --- /dev/null +++ b/condition/README.md @@ -0,0 +1,23 @@ +Prepare the preprocessing model + +Hed: https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth\ +Lineart: https://huggingface.co/spaces/awacke1/Image-to-Line-Drawings/resolve/main/model.pth\ +depth: https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt (hybrid for inference)\ + https://huggingface.co/Intel/dpt-large (large for test conditional consistency and fid)\ + +We recommend storing them in the following paths + + |---condition + |---ckpts + |---dpt_large + |---config.json + |---preprocessor_config.json + |---pytorch_model.bin + |---ControlNetHED.pth + |---dpt_hybrid-midas-501f0c75.pt + |---model.pth + |---example + |---midas + . + . + . 
\ No newline at end of file diff --git a/condition/canny.py b/condition/canny.py new file mode 100644 index 0000000000000000000000000000000000000000..2918b6cd2617b14c77eca0aaba20a0632f349d82 --- /dev/null +++ b/condition/canny.py @@ -0,0 +1,25 @@ +import cv2 +import torch +import numpy as np + + +class CannyDetector: + def __call__(self, img, low_threshold=100, high_threshold=200): + """ + input: array or tensor (H,W,3) + output: array (H,W) + """ + if torch.is_tensor(img): + img = img.cpu().detach().numpy().astype(np.uint8) + return cv2.Canny(img, low_threshold, high_threshold) + + +if __name__ == '__main__': + apply_canny = CannyDetector() + img = cv2.imread('condition/dragon_resize.png') + import numpy as np + print(img.max()) + detected_map = apply_canny(img, 100, 200) + print(detected_map.shape, detected_map.max(), detected_map.min()) + cv2.imwrite('condition/example_canny.jpg', detected_map) + np.save('condition/example_canny.npy', detected_map[None,None]) \ No newline at end of file diff --git a/condition/depth.py b/condition/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4190bc46d7ea55a90a2812ee245713274addadcf --- /dev/null +++ b/condition/depth.py @@ -0,0 +1,47 @@ +from controlnet_aux import LineartDetector +import torch +import cv2 +import numpy as np +from transformers import DPTImageProcessor, DPTForDepthEstimation +class Depth: + def __init__(self, device): + self.model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large") + + def __call__(self, input_image): + """ + input: tensor() + """ + control_image = self.model(input_image) + return np.array(control_image) + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from tqdm import tqdm + from transformers import DPTImageProcessor, DPTForDepthEstimation + from PIL import Image + + image = Image.open('condition/example/t2i/depth/depth.png') + img = cv2.imread('condition/example/t2i/depth/depth.png') + processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large") + model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large") + + inputs = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float()# + inputs = 2*(inputs/255 - 0.5) + inputs = processor(images=image, return_tensors="pt", size=(512,512)) + print(inputs) + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + print(predicted_depth.shape) + prediction = torch.nn.functional.interpolate( + predicted_depth.unsqueeze(1), + size=image.size[::-1], + mode="bicubic", + align_corners=False, + ) + + output = prediction.squeeze().cpu().numpy() + formatted = (output * 255 / np.max(output)).astype("uint8") + + depth = Image.fromarray(formatted) + depth.save('condition/example/t2i/depth/example_depth.jpg') \ No newline at end of file diff --git a/condition/example/t2i/multi_resolution/bird.jpg b/condition/example/t2i/multi_resolution/bird.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f777b39adf619fb5c5e279ab3ff56b6663f447a Binary files /dev/null and b/condition/example/t2i/multi_resolution/bird.jpg differ diff --git a/condition/example/t2i/multi_resolution/car.jpg b/condition/example/t2i/multi_resolution/car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6ced0af28d9dd90bf816c0ab0bc1747fcf15c30e Binary files /dev/null and b/condition/example/t2i/multi_resolution/car.jpg differ diff --git a/condition/example/t2i/multigen/doll.jpg b/condition/example/t2i/multigen/doll.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..755591e2ecf1b400ed5430edf2ce2c4b6ceb0421 Binary files /dev/null and b/condition/example/t2i/multigen/doll.jpg differ diff --git a/condition/example/t2i/multigen/girl.jpg b/condition/example/t2i/multigen/girl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d698a5f23a14d2b3990884f22f9fb535dfc341a9 Binary files /dev/null and b/condition/example/t2i/multigen/girl.jpg differ diff --git a/condition/example/t2i/multigen/house.jpg b/condition/example/t2i/multigen/house.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5ca84748371c17fa6f94944ffb73ea54a5f98e2 Binary files /dev/null and b/condition/example/t2i/multigen/house.jpg differ diff --git a/condition/example/t2i/multigen/sofa.png b/condition/example/t2i/multigen/sofa.png new file mode 100644 index 0000000000000000000000000000000000000000..de085957c27a0202cee66d7492fc7bbee454cb0b Binary files /dev/null and b/condition/example/t2i/multigen/sofa.png differ diff --git a/condition/hed.py b/condition/hed.py new file mode 100644 index 0000000000000000000000000000000000000000..10dab658ab139ae30feeabe7acf226023185a873 --- /dev/null +++ b/condition/hed.py @@ -0,0 +1,117 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np +from torch.nn.parallel import DataParallel +from einops import rearrange +from condition.utils import annotator_ckpts_path +import torch.nn.functional as F + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, 
down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector(torch.nn.Module): + def __init__(self): + super().__init__() + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth" + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = ControlNetHED_Apache2().float()#.to(self.device).eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + """ + input: tensor (B,C,H,W) + output: tensor (B,H,W) + """ + B, C, H, W = input_image.shape + image_hed = input_image + + edges = self.netNetwork(image_hed) + edges = [F.interpolate(e, size=(H, W), mode='bilinear', align_corners=False).squeeze(1) for e in edges] + edges = torch.stack(edges, dim=1) + edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1))) + edge = (edge * 255.0).clamp(0, 255) + + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from tqdm import tqdm + import torch.nn.functional as F + device = torch.device('cuda') + apply_hed = HEDdetector().to(device).eval() + img = cv2.imread('condition/dragon_1024_512.jpg') + H,W = img.shape[:2] + resize_img = cv2.resize(img,(512,1024)) + detected_map = apply_hed(torch.from_numpy(img).permute(2,0,1).unsqueeze(0).cuda()) + resize_detected_map = apply_hed(torch.from_numpy(resize_img).permute(2,0,1).unsqueeze(0).cuda()) + cv2.imwrite('condition/example_hed_resize.jpg', resize_detected_map[0].cpu().detach().numpy()) + resize_detected_map = F.interpolate(resize_detected_map.unsqueeze(0).to(torch.float32), size=(H,W), mode='bilinear', align_corners=False, antialias=True) + print(abs(detected_map - resize_detected_map).sum()) + print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min()) + cv2.imwrite('condition/example_hed.jpg', detected_map[0].cpu().detach().numpy()) + cv2.imwrite('condition/example_hed_resized.jpg', resize_detected_map[0,0].cpu().detach().numpy()) \ No newline at end of file diff --git a/condition/lineart.py b/condition/lineart.py new file mode 100644 index 0000000000000000000000000000000000000000..8d79c5e8ea695a7d4a88e469fce316eee561e7d9 --- /dev/null +++ b/condition/lineart.py @@ -0,0 +1,98 @@ +from controlnet_aux import LineartDetector +import torch +import cv2 +import numpy as np +import torch.nn as nn + + +norm_layer = nn.InstanceNorm2d +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [ nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + 
nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) +class LineArt(nn.Module): + def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmoid=True): + super(LineArt, self).__init__() + + # Initial convolution block + model0 = [ nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True) ] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features*2 + for _ in range(2): + model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features*2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features//2 + for _ in range(2): + model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features//2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [ nn.ReflectionPad2d(3), + nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): + """ + input: tensor (B,C,H,W) + output: tensor (B,1,H,W) 0~1 + """ + + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from tqdm import tqdm + apply_lineart = LineArt() + apply_lineart.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu'))) + img = cv2.imread('condition/car_448_768.jpg') + img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).repeat(8,1,1,1).float() + detected_map = apply_lineart(img) + print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min()) + cv2.imwrite('condition/example_lineart.jpg', 255*detected_map[0,0].cpu().detach().numpy()) \ No newline at end of file diff --git a/condition/midas/depth.py b/condition/midas/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbf51d698facc991691e477b1f7088f10299cee --- /dev/null +++ b/condition/midas/depth.py @@ -0,0 +1,223 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch +import os +import sys +current_directory = os.getcwd() +sys.path.append(current_directory) +from einops import rearrange +# from .api import MiDaSInference +from condition.utils import annotator_ckpts_path +from condition.midas.midas.dpt_depth import DPTDepthModel +from condition.midas.midas.midas_net import MidasNet +from condition.midas.midas.midas_net_custom import MidasNet_small +from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet +import os +import torch.nn as nn +from torchvision.transforms import Compose + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = 
"https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + 
super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + + +class MidasDetector: + def __init__(self,device=torch.device('cuda:0'), model_type="dpt_hybrid"): + self.device = device + self.model = MiDaSInference(model_type=model_type).to(device) + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = image_depth + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + # normal = np.stack([x, y, z], axis=2) + # normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + # normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image#, normal_image + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from tqdm import tqdm + from PIL import Image + import torchvision.transforms.functional as F + apply_depth = MidasDetector(device=torch.device('cuda:0')) + img = cv2.imread('/data/vjuicefs_sz_cv_v2/11171709/ControlAR_github/condition/example/t2i/multi_resolution/car_1_448_768.jpg') + img = cv2.resize(img,(768,448)) + detected_map = apply_depth(torch.from_numpy(img).cuda().float()) + print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min()) + plt.imshow(detected_map, cmap='gray') + plt.show() + cv2.imwrite('condition/example_depth.jpg', detected_map) + # cv2.imwrite('condition/example_normal.jpg', normal_map) diff --git a/condition/midas/midas/__init__.py b/condition/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/condition/midas/midas/base_model.py b/condition/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a35b20deca1fb7b5778c95675a7aa538ee1d733d --- /dev/null +++ b/condition/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) \ No newline at end of file diff --git a/condition/midas/midas/blocks.py b/condition/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..ba735d6c718e018a26b1d558bad358ee606280bb --- /dev/null +++ b/condition/midas/midas/blocks.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = 
nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output \ No newline at end of file diff --git a/condition/midas/midas/dpt_depth.py b/condition/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa3a562fa0751967f54a6d734301602de39ebad --- /dev/null +++ b/condition/midas/midas/dpt_depth.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = 
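Editor's note: `ResidualConvUnit_custom` and `FeatureFusionBlock_custom` above route their skip additions through `nn.quantized.FloatFunctional`, which behaves like a plain `+` in float mode but remains visible to PyTorch's quantization tooling. A minimal sketch of the pattern (the toy block here is illustrative, not part of the repo):

```
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """Residual block whose skip-add stays observable during quantization."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.ReLU()
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.conv(self.act(x))
        # Equivalent to `out + x` in float mode.
        return self.skip_add.add(out, x)

y = TinyResidual(8)(torch.randn(1, 8, 16, 16))
print(y.shape)   # torch.Size([1, 8, 16, 16])
```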
self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) \ No newline at end of file diff --git a/condition/midas/midas/midas_net.py b/condition/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..4f95fcd8aabd574f178ad4515b8cdcc8979281d8 --- /dev/null +++ b/condition/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
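Editor's note: a hedged usage sketch for the `DPTDepthModel` defined above. It assumes the repository root is on `PYTHONPATH` and that the installed `timm` version can still create `vit_base_resnet50_384`; with `path=None` the weights stay randomly initialized, so nothing is downloaded:

```
import torch
from condition.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()

x = torch.randn(1, 3, 384, 384)   # the hybrid backbone is built around 384-pixel inputs
with torch.no_grad():
    depth = model(x)
print(depth.shape)                # expected: torch.Size([1, 384, 384])
```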
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) \ No newline at end of file diff --git a/condition/midas/midas/midas_net_custom.py b/condition/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/condition/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/condition/midas/midas/transforms.py b/condition/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0d15e8fa8f10804b7bb3bda6bdfbf1daef6506 --- /dev/null +++ b/condition/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
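Editor's note: `fuse_model` above scans `named_modules()` for Conv -> BatchNorm (-> ReLU) runs and fuses them in place. The core call it issues looks like this on a toy module (fusion of this kind expects eval mode):

```
import torch
import torch.nn as nn

m = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),   # "0"
    nn.BatchNorm2d(8),               # "1"
    nn.ReLU(),                       # "2"
).eval()

# Same call fuse_model() makes once it has located a Conv -> BN -> ReLU run.
torch.quantization.fuse_modules(m, ["0", "1", "2"], inplace=True)
print(m)   # "0" becomes a fused conv+ReLU with BN folded in; "1" and "2" become Identity
```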
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
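Editor's note: the interplay of `get_size` and `constrain_to_multiple_of` above is easiest to follow on concrete numbers. A sketch of the `keep_aspect_ratio` + `"lower_bound"` + multiple-of-32 case, re-implemented inline so it runs without the class (the `max_val` branch is omitted for brevity):

```
import numpy as np

def constrain_to_multiple_of(x, multiple_of=32, min_val=0):
    y = int(np.round(x / multiple_of) * multiple_of)
    if y < min_val:
        y = int(np.ceil(x / multiple_of) * multiple_of)
    return y

# Target 384x384, input 640x480 (W x H):
scale = max(384 / 480, 384 / 640)   # pick the larger scale so both sides cover 384
new_h = constrain_to_multiple_of(scale * 480, min_val=384)   # -> 384
new_w = constrain_to_multiple_of(scale * 640, min_val=384)   # -> 512
print(new_w, new_h)   # (512, 384): at least 384 on each side, both multiples of 32
```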
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample \ No newline at end of file diff --git a/condition/midas/midas/vit.py b/condition/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..75671d64f534b5bed047c9374b5fc3ccf0e15d23 --- /dev/null +++ b/condition/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h 
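Editor's note: `Resize`, `NormalizeImage` and `PrepareForNet` above are designed to be chained, in that order, into one preprocessing transform. A hedged sketch of such a pipeline; the import path matches the files added in this diff, and the mean/std values are the usual ImageNet statistics, used here only for illustration:

```
import numpy as np
from torchvision.transforms import Compose
from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(
        384, 384,
        resize_target=False,          # only the image is resized here
        keep_aspect_ratio=True,
        ensure_multiple_of=32,
        resize_method="lower_bound",
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),                  # HWC float -> contiguous CHW float32
])

img = np.random.rand(480, 640, 3).astype(np.float32)   # RGB in [0, 1]
sample = transform({"image": img})
print(sample["image"].shape, sample["image"].dtype)     # (3, 384, 512) float32
```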
// pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + 
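Editor's note: `_resize_pos_embed` above lets the ViT run at resolutions other than 384x384 by splitting off the readout token and bilinearly resizing the positional grid. A standalone sketch of the same idea for a 448x768 input with 16-pixel patches:

```
import math
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 1 + 24 * 24, 768)   # [CLS] token + 24x24 patch grid
gs_h, gs_w = 448 // 16, 768 // 16              # new grid: 28 x 48

tok, grid = pos_embed[:, :1], pos_embed[0, 1:]
gs_old = int(math.sqrt(grid.shape[0]))         # 24
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

new_pos_embed = torch.cat([tok, grid], dim=1)
print(new_pos_embed.shape)                     # torch.Size([1, 1345, 768])
```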
pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + 
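Editor's note: the backbone factory above collects intermediate transformer features by registering forward hooks that write into a shared `activations` dict, and injects `forward_flex` with `types.MethodType`. The hook mechanism in isolation, on a toy network:

```
import torch
import torch.nn as nn

activations = {}

def get_activation(name):
    def hook(module, inputs, output):
        activations[name] = output
    return hook

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
net[0].register_forward_hook(get_activation("1"))
net[2].register_forward_hook(get_activation("2"))

_ = net(torch.randn(2, 8))
print(activations["1"].shape, activations["2"].shape)   # (2, 8) and (2, 4)
```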
use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) \ No newline at end of file diff --git a/condition/utils.py b/condition/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00e73236557639f67a25f102f11f5bbd4bbe4515 --- /dev/null +++ b/condition/utils.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img \ No newline at end of file diff --git a/language/README.md b/language/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2c252e41b784e6377d945697a243f0a75ce9079a --- /dev/null +++ b/language/README.md @@ -0,0 +1,14 @@ +## Language models for text-conditional image generation + +### Requirements +``` +pip install ftfy +pip install transformers +pip install accelerate +pip install sentencepiece +pip install pandas +pip install bs4 +``` + +### Language Models +Download flan-t5-xl models from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put into the folder of `./pretrained_models/t5-ckpt/` diff --git a/language/extract_t5_feature.py b/language/extract_t5_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..8980dca9cbdc689db5ab01dfc8d242a8f8cdde6d --- /dev/null +++ b/language/extract_t5_feature.py @@ -0,0 +1,129 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +import numpy as np +import argparse +import os +import json + +from utils.distributed import init_distributed_mode +from language.t5 import T5Embedder + +CAPTION_KEY = { + 'blip': 0, + 'llava': 1, + 'llava_first': 2, +} +################################################################################# +# Training Helper Functions # +################################################################################# +class CustomDataset(Dataset): + def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False): + img_path_list = [] + for lst_name in sorted(os.listdir(lst_dir))[start: end+1]: + if not lst_name.endswith('.jsonl'): + continue + file_path = os.path.join(lst_dir, lst_name) + with open(file_path, 
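Editor's note: `resize_image` in `condition/utils.py` above scales the short side to `resolution`, snaps both sides to multiples of 64, and picks Lanczos for upscaling versus area interpolation for downscaling. The arithmetic on one concrete input:

```
import numpy as np

H, W, resolution = 480, 640, 512
k = resolution / min(H, W)                  # 512 / 480 ~ 1.067 -> upscaling, so Lanczos
H2 = int(np.round(H * k / 64.0)) * 64       # 512
W2 = int(np.round(W * k / 64.0)) * 64       # 704
print(k, (W2, H2))
```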
'r') as file: + for line_idx, line in enumerate(file): + data = json.loads(line) + # caption = data[caption_key] + caption = data['text'][CAPTION_KEY[caption_key]] + code_dir = file_path.split('/')[-1].split('.')[0] + if trunc_caption: + caption = caption.split('.')[0] + img_path_list.append((caption, code_dir, line_idx)) + self.img_path_list = img_path_list + + def __len__(self): + return len(self.img_path_list) + + def __getitem__(self, index): + caption, code_dir, code_name = self.img_path_list[index] + return caption, code_dir, code_name + + + +################################################################################# +# Training Loop # +################################################################################# +def main(args): + """ + Trains a new DiT model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + # Setup DDP: + # dist.init_process_group("nccl") + init_distributed_mode(args) + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # Setup a feature folder: + if rank == 0: + os.makedirs(args.t5_path, exist_ok=True) + + # Setup data: + print(f"Dataset is preparing...") + dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=1, # important! + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + print(f"Dataset contains {len(dataset):,} images") + + precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] + assert os.path.exists(args.t5_model_path) + t5_xxl = T5Embedder( + device=device, + local_cache=True, + cache_dir=args.t5_model_path, + dir_or_name=args.t5_model_type, + torch_dtype=precision + ) + + for caption, code_dir, code_name in loader: + caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption) + valid_caption_embs = caption_embs[:, :emb_masks.sum()] + x = valid_caption_embs.to(torch.float32).detach().cpu().numpy() + os.makedirs(os.path.join(args.t5_path, code_dir[0]), exist_ok=True) + np.save(os.path.join(args.t5_path, code_dir[0], '{}.npy'.format(code_name.item())), x) + print(code_name.item()) + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--t5-path", type=str, required=True) + parser.add_argument("--data-start", type=int, required=True) + parser.add_argument("--data-end", type=int, required=True) + parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys())) + parser.add_argument("--trunc-caption", action='store_true', default=False) + parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt') + parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') + parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=24) + args = parser.parse_args() + main(args) diff --git a/language/t5.py 
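Editor's note: the extraction loop above stores only the non-padded prefix of each T5 embedding via `caption_embs[:, :emb_masks.sum()]`, which is one reason `batch_size=1` matters there. A tiny sketch of that slice for a single right-padded caption:

```
import torch

# One caption, max length 8, of which only 5 tokens are real (right padding).
caption_embs = torch.randn(1, 8, 2048)
emb_masks = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

valid = caption_embs[:, :emb_masks.sum()]   # keep only the real tokens
print(valid.shape)                          # torch.Size([1, 5, 2048])
```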
b/language/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..9a548cdd3bf331632c9dd8c0a078799e3e50ee21 --- /dev/null +++ b/language/t5.py @@ -0,0 +1,201 @@ +# Modified from: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py +import os +import re +import html +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from transformers import T5EncoderModel, AutoTokenizer +from huggingface_hub import hf_hub_download + + +class T5Embedder: + available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl'] + bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + + def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, + t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} + t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') + self.dir_or_name = dir_or_name + tokenizer_path, path = dir_or_name, dir_or_name + if local_cache: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + tokenizer_path, path = cache_dir, cache_dir + elif dir_or_name in self.available_models: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin' + ]: + hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path, path = cache_dir, cache_dir + else: + cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + ]: + hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path = cache_dir + + print(tokenizer_path) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + + text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] + text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] + + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask['input_ids'].to(self.device), + attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), + )['last_hidden_state'].detach() + return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the 
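Editor's note: a hedged usage sketch for the `T5Embedder` above. It assumes the flan-t5-xl weights have been placed under `./pretrained_models/t5-ckpt/flan-t5-xl` as described in `language/README.md`, and that a GPU is available; the prompt is arbitrary:

```
import torch
from language.t5 import T5Embedder

t5 = T5Embedder(
    device="cuda",
    local_cache=True,
    cache_dir="./pretrained_models/t5-ckpt",
    dir_or_name="flan-t5-xl",
    torch_dtype=torch.bfloat16,
    model_max_length=120,
)
embs, masks = t5.get_text_embeddings(["a colorful bird on a branch"])
print(embs.shape, masks.shape)   # expected (1, 120, 2048) and (1, 120) for flan-t5-xl
```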
training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . 
" + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a82ec36b6e4afb50183a54ea3d388d162dbc69f8 --- /dev/null +++ b/model.py @@ -0,0 +1,242 @@ +import gc +import spaces +from safetensors.torch import load_file +from autoregressive.models.gpt_t2i import GPT_models +from tokenizer.tokenizer_image.vq_model import VQ_models +from language.t5 import T5Embedder +import torch +import numpy as np +import PIL +from PIL import Image +from condition.canny import CannyDetector +import time +from autoregressive.models.generate import generate +from condition.midas.depth import MidasDetector + +models = { + "canny": "checkpoints/t2i/canny_MR.safetensors", + "depth": "checkpoints/t2i/depth_MR.safetensors", +} + + +def resize_image_to_16_multiple(image, condition_type='canny'): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + # image = Image.open(image_path) + width, height = image.size + + if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32 + new_width = (width + 31) // 32 * 32 + new_height = (height + 31) // 32 * 32 + else: + new_width = (width + 15) // 16 * 16 + new_height = (height + 15) // 16 * 16 + + resized_image = image.resize((new_width, new_height)) + return resized_image + + +class Model: + + def __init__(self): + self.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu") + self.base_model_id = "" + self.task_name = "" + self.vq_model = self.load_vq() + self.t5_model = self.load_t5() + self.gpt_model_canny = self.load_gpt(condition_type='canny') + self.gpt_model_depth = self.load_gpt(condition_type='depth') + self.get_control_canny = CannyDetector() + self.get_control_depth = MidasDetector(device=self.device) + + def load_vq(self): + vq_model = VQ_models["VQ-16"](codebook_size=16384, + codebook_embed_dim=8) + vq_model.to(self.device) + vq_model.eval() + checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt", + map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + del checkpoint + print(f"image tokenizer is loaded") + return vq_model + + def load_gpt(self, condition_type='canny'): + gpt_ckpt = 
models[condition_type] + precision = torch.bfloat16 + latent_size = 768 // 16 + gpt_model = GPT_models["GPT-XL"]( + block_size=latent_size**2, + cls_token_num=120, + model_type='t2i', + condition_type=condition_type, + ).to(device=self.device, dtype=precision) + + model_weight = load_file(gpt_ckpt) + gpt_model.load_state_dict(model_weight, strict=False) + gpt_model.eval() + print(f"gpt model is loaded") + return gpt_model + + def load_t5(self): + precision = torch.bfloat16 + t5_model = T5Embedder( + device=self.device, + local_cache=True, + # cache_dir='checkpoints/t5-ckpt', + dir_or_name='flan-t5-xl', + torch_dtype=precision, + model_max_length=120, + ) + return t5_model + + @torch.no_grad() + @spaces.GPU(enable_queue=True) + def process_canny( + self, + image: np.ndarray, + prompt: str, + cfg_scale: float, + temperature: float, + top_k: int, + top_p: int, + seed: int, + low_threshold: int, + high_threshold: int, + ) -> list[PIL.Image.Image]: + + image = resize_image_to_16_multiple(image, 'canny') + W, H = image.size + print(W, H) + condition_img = self.get_control_canny(np.array(image), low_threshold, + high_threshold) + condition_img = torch.from_numpy(condition_img[None, None, + ...]).repeat( + 2, 3, 1, 1) + condition_img = condition_img.to(self.device) + condition_img = 2 * (condition_img / 255 - 0.5) + prompts = [prompt] * 2 + caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts) + + print(f"processing left-padding...") + new_emb_masks = torch.flip(emb_masks, dims=[-1]) + new_caption_embs = [] + for idx, (caption_emb, + emb_mask) in enumerate(zip(caption_embs, emb_masks)): + valid_num = int(emb_mask.sum().item()) + print(f' prompt {idx} token len: {valid_num}') + new_caption_emb = torch.cat( + [caption_emb[valid_num:], caption_emb[:valid_num]]) + new_caption_embs.append(new_caption_emb) + new_caption_embs = torch.stack(new_caption_embs) + c_indices = new_caption_embs * new_emb_masks[:, :, None] + c_emb_masks = new_emb_masks + qzshape = [len(c_indices), 8, H // 16, W // 16] + t1 = time.time() + index_sample = generate( + self.gpt_model_canny, + c_indices, + (H // 16) * (W // 16), + c_emb_masks, + condition=condition_img, + cfg_scale=cfg_scale, + temperature=temperature, + top_k=top_k, + top_p=top_p, + sample_logits=True, + ) + sampling_time = time.time() - t1 + print(f"Full sampling takes about {sampling_time:.2f} seconds.") + + t2 = time.time() + print(index_sample.shape) + samples = self.vq_model.decode_code( + index_sample, qzshape) # output value is between [-1, 1] + decoder_time = time.time() - t2 + print(f"decoder takes about {decoder_time:.2f} seconds.") + + samples = torch.cat((condition_img[0:1], samples), dim=0) + samples = 255 * (samples * 0.5 + 0.5) + samples = [image] + [ + Image.fromarray( + sample.permute(1, 2, 0).cpu().detach().numpy().clip( + 0, 255).astype(np.uint8)) for sample in samples + ] + del condition_img + torch.cuda.empty_cache() + return samples + + @torch.no_grad() + @spaces.GPU(enable_queue=True) + def process_depth( + self, + image: np.ndarray, + prompt: str, + cfg_scale: float, + temperature: float, + top_k: int, + top_p: int, + seed: int, + ) -> list[PIL.Image.Image]: + image = resize_image_to_16_multiple(image, 'depth') + W, H = image.size + print(W, H) + image_tensor = torch.from_numpy(np.array(image)).to(self.device) + condition_img = torch.from_numpy( + self.get_control_depth(image_tensor)).unsqueeze(0) + condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1) + condition_img = condition_img.to(self.device) + condition_img = 2 
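Editor's note: `process_canny` above converts the right-padded T5 output into left-padded sequences: the attention mask is flipped and each embedding is rotated so its padding slots come first. The rearrangement on a toy embedding:

```
import torch

# One caption embedding: 4 real tokens followed by 2 padding tokens, dim 3.
emb = torch.arange(18, dtype=torch.float32).reshape(6, 3)
mask = torch.tensor([1, 1, 1, 1, 0, 0])

valid = int(mask.sum())
left_padded = torch.cat([emb[valid:], emb[:valid]])   # padding first, real tokens last
left_mask = torch.flip(mask, dims=[-1])               # [0, 0, 1, 1, 1, 1]
print(left_mask.tolist())
print(left_padded[:, 0].tolist())   # rows 4 and 5 (padding) now precede rows 0..3
```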
* (condition_img / 255 - 0.5) + prompts = [prompt] * 2 + caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts) + + print(f"processing left-padding...") + new_emb_masks = torch.flip(emb_masks, dims=[-1]) + new_caption_embs = [] + for idx, (caption_emb, + emb_mask) in enumerate(zip(caption_embs, emb_masks)): + valid_num = int(emb_mask.sum().item()) + print(f' prompt {idx} token len: {valid_num}') + new_caption_emb = torch.cat( + [caption_emb[valid_num:], caption_emb[:valid_num]]) + new_caption_embs.append(new_caption_emb) + new_caption_embs = torch.stack(new_caption_embs) + + c_indices = new_caption_embs * new_emb_masks[:, :, None] + c_emb_masks = new_emb_masks + qzshape = [len(c_indices), 8, H // 16, W // 16] + t1 = time.time() + index_sample = generate( + self.gpt_model_depth, + c_indices, + (H // 16) * (W // 16), + c_emb_masks, + condition=condition_img, + cfg_scale=cfg_scale, + temperature=temperature, + top_k=top_k, + top_p=top_p, + sample_logits=True, + ) + sampling_time = time.time() - t1 + print(f"Full sampling takes about {sampling_time:.2f} seconds.") + + t2 = time.time() + print(index_sample.shape) + samples = self.vq_model.decode_code(index_sample, qzshape) + decoder_time = time.time() - t2 + print(f"decoder takes about {decoder_time:.2f} seconds.") + condition_img = condition_img.cpu() + samples = samples.cpu() + samples = torch.cat((condition_img[0:1], samples), dim=0) + samples = 255 * (samples * 0.5 + 0.5) + samples = [image] + [ + Image.fromarray( + sample.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8)) + for sample in samples + ] + del image_tensor + del condition_img + torch.cuda.empty_cache() + return samples diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..39fafdc10147d039303439d70f168325cfaf31e7 --- /dev/null +++ b/style.css @@ -0,0 +1,10 @@ +h1 { + text-align: center; + } + + #duplicate-button { + margin: auto; + color: #fff; + background: #1565c0; + border-radius: 100vh; + } \ No newline at end of file diff --git a/tokenizer/consistencydecoder/README.md b/tokenizer/consistencydecoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9565f597f55f9c6c25318fc8b4ea2fddddcbbf6d --- /dev/null +++ b/tokenizer/consistencydecoder/README.md @@ -0,0 +1,14 @@ +## Consistency Decoder from OpenAI + +### install +``` +pip install diffusers +pip install accelerate +``` + +### demo +``` +cd ${THIS_REPO_ROOT} +python3 tokenizer/consistencydecoder/cd_demo.py +``` + diff --git a/tokenizer/consistencydecoder/cd_demo.py b/tokenizer/consistencydecoder/cd_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..2cdb5532b5098d59a7e7baf02c849e92b87fb341 --- /dev/null +++ b/tokenizer/consistencydecoder/cd_demo.py @@ -0,0 +1,57 @@ +import argparse +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image +from diffusers import ConsistencyDecoderVAE + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device) + + # load image + img_path = args.image_path + out_path = args.image_path.replace('.jpg', '_cd.jpg').replace('.jpeg', '_cd.jpeg').replace('.png', '_cd.png') + input_size = args.image_size + img = Image.open(img_path).convert("RGB") + + # preprocess + size_org = img.size + img = 
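Editor's note: both `process_*` methods above move data between `uint8` `[0, 255]` and the model's `[-1, 1]` range; the two mappings used in the code are exact inverses of each other:

```
import numpy as np

img = np.array([[0, 64, 128, 255]], dtype=np.float32)

x = 2 * (img / 255 - 0.5)        # [0, 255] -> [-1, 1], as for condition images
back = 255 * (x * 0.5 + 0.5)     # [-1, 1] -> [0, 255], as for decoded samples
print(x.tolist())
print(np.allclose(back, img))    # True
```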
img.resize((input_size, input_size)) + img = np.array(img) / 255. + x = 2.0 * img - 1.0 # x value is between [-1, 1] + x = torch.tensor(x) + x = x.unsqueeze(dim=0) + x = torch.einsum('nhwc->nchw', x) + x_input = x.half().to(device) + + # inference + with torch.no_grad(): + # Map input images to latent space + normalize latents: + latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215) + # reconstruct: + output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] + + # postprocess + output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] + sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() + + # save + Image.fromarray(sample).save(out_path) + print("Reconstructed image is saved to {}".format(out_path)) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image-path", type=str, default="assets/example.jpg") + parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/tokenizer/consistencydecoder/reconstruction_cd_ddp.py b/tokenizer/consistencydecoder/reconstruction_cd_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..c7bf7c07e94c65125211a9f2464907e7ff162e2b --- /dev/null +++ b/tokenizer/consistencydecoder/reconstruction_cd_ddp.py @@ -0,0 +1,208 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +from tqdm import tqdm +import os +import itertools +from PIL import Image +import numpy as np +import argparse +import random + +from skimage.metrics import peak_signal_noise_ratio as psnr_loss +from skimage.metrics import structural_similarity as ssim_loss +from diffusers.models import ConsistencyDecoderVAE + + +class SingleFolderDataset(Dataset): + def __init__(self, directory, transform=None): + super().__init__() + self.directory = directory + self.transform = transform + self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) + if os.path.isfile(os.path.join(directory, file_name))] + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + image = Image.open(image_path).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(0) + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + + random.shuffle(samples) # This is very important for IS(Inception Score) !!! + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup env + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load model + vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device)) + + # Create folder to save samples: + folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + if args.dataset == 'imagenet': + dataset = ImageFolder(args.data_path, transform=transform) + num_fid_samples = 50000 + elif args.dataset == 'coco': + dataset = SingleFolderDataset(args.data_path, transform=transform) + num_fid_samples = 5000 + else: + raise Exception("please check dataset") + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=args.per_proc_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + psnr_val_rgb = [] + ssim_val_rgb = [] + + loader = tqdm(loader) if rank == 0 else loader + total = 0 + for x, _ in loader: + rgb_gts = x + rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] + x = x.half().to("cuda:{}".format(device)) + with torch.no_grad(): + # Map input images to latent space + normalize latents: + latent = vae.encode(x).latent_dist.sample().mul_(0.18215) + # reconstruct: + samples = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): + index = i * dist.get_world_size() + rank + total + 
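The per-rank file index computed above interleaves ranks (`i * world_size + rank + total`) so that no two processes ever write the same filename, and full batches tile a contiguous index range. A minimal sketch of that scheme, with toy values standing in for the actual DDP world size and per-process batch size:

```python
# Toy illustration of the index layout used by the DDP reconstruction scripts.
# 'world_size', 'batch_size' and 'num_batches' are made-up values, not the real run config.
def rank_indices(rank, world_size, batch_size, num_batches):
    total, out = 0, []
    for _ in range(num_batches):
        out.extend(i * world_size + rank + total for i in range(batch_size))
        total += batch_size * world_size  # matches 'total += global_batch_size'
    return out

print(rank_indices(0, 2, 3, 2))  # [0, 2, 4, 6, 8, 10]
print(rank_indices(1, 2, 3, 2))  # [1, 3, 5, 7, 9, 11]
```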
Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + # metric + rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] + psnr = psnr_loss(rgb_restored, rgb_gt) + ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) + psnr_val_rgb.append(psnr) + ssim_val_rgb.append(ssim) + total += global_batch_size + + # ------------------------------------ + # Summary + # ------------------------------------ + # Make sure all processes have finished saving their samples + dist.barrier() + world_size = dist.get_world_size() + gather_psnr_val = [None for _ in range(world_size)] + gather_ssim_val = [None for _ in range(world_size)] + dist.all_gather_object(gather_psnr_val, psnr_val_rgb) + dist.all_gather_object(gather_ssim_val, ssim_val_rgb) + + if rank == 0: + gather_psnr_val = list(itertools.chain(*gather_psnr_val)) + gather_ssim_val = list(itertools.chain(*gather_ssim_val)) + psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) + ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) + + result_file = f"{sample_folder_dir}_results.txt" + print("writing results to {}".format(result_file)) + with open(result_file, 'w') as f: + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) + + create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) + print("Done.") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--sample-dir", type=str, default="reconstructions") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=4) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/tokenizer_image/cache/vgg.pth b/tokenizer/tokenizer_image/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/tokenizer/tokenizer_image/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/tokenizer/tokenizer_image/discriminator.py b/tokenizer/tokenizer_image/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..35110c1518690d67e7319b24d57a51d7e11f6021 --- /dev/null +++ b/tokenizer/tokenizer_image/discriminator.py @@ -0,0 +1,255 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py +import functools +import math +import torch +import torch.nn as nn +try: + from kornia.filters import filter2d +except: + pass + +################################################################################# +# PatchGAN # +################################################################################# +class PatchGANDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see 
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(PatchGANDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.normal_(module.weight.data, 0.0, 0.02) + elif isinstance(module, nn.BatchNorm2d): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and 
self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + + +################################################################################# +# StyleGAN # +################################################################################# +class StyleGANDiscriminator(nn.Module): + def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256): + super().__init__() + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + log_size = int(math.log(image_size, 2)) + in_channel = channels[image_size] + + blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + blocks.append(DiscriminatorBlock(in_channel, out_channel)) + in_channel = out_channel + self.blocks = nn.ModuleList(blocks) + + self.final_conv = nn.Sequential( + nn.Conv2d(in_channel, channels[4], 3, padding=1), + leaky_relu(), + ) + self.final_linear = nn.Sequential( + nn.Linear(channels[4] * 4 * 4, channels[4]), + leaky_relu(), + nn.Linear(channels[4], 1) + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.final_conv(x) + x = x.view(x.shape[0], -1) + x = self.final_linear(x) + return x + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu() + ) + + self.downsample = nn.Sequential( + Blur(), + nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) + ) if downsample else None + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer('f', f) + + def forward(self, x): + f = self.f + f = f[None, None, :] * f [None, :, None] + return filter2d(x, f, normalized=True) + + +def leaky_relu(p=0.2): + return nn.LeakyReLU(p, inplace=True) + + +def exists(val): + return val is not None \ No newline at end of file diff --git a/tokenizer/tokenizer_image/discriminator_patchgan.py b/tokenizer/tokenizer_image/discriminator_patchgan.py new file mode 100644 index 0000000000000000000000000000000000000000..aeab654c4533b287852c07e01e3995f3f8b8067a --- /dev/null +++ b/tokenizer/tokenizer_image/discriminator_patchgan.py @@ -0,0 +1,152 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +import functools +import torch +import torch.nn as nn + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, 
input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.normal_(module.weight.data, 0.0, 0.02) + elif isinstance(module, nn.BatchNorm2d): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + 
"disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h \ No newline at end of file diff --git a/tokenizer/tokenizer_image/discriminator_stylegan.py b/tokenizer/tokenizer_image/discriminator_stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..17c4b50d2ee4c198cb188ff4f3582cb9ef145073 --- /dev/null +++ b/tokenizer/tokenizer_image/discriminator_stylegan.py @@ -0,0 +1,101 @@ +# Modified from: +# stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py +# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py +import math +import torch +import torch.nn as nn +try: + from kornia.filters import filter2d +except: + pass + +class Discriminator(nn.Module): + def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256): + super().__init__() + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + log_size = int(math.log(image_size, 2)) + in_channel = channels[image_size] + + blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + blocks.append(DiscriminatorBlock(in_channel, out_channel)) + in_channel = out_channel + self.blocks = nn.ModuleList(blocks) + + self.final_conv = nn.Sequential( + nn.Conv2d(in_channel, channels[4], 3, padding=1), + leaky_relu(), + ) + self.final_linear = nn.Sequential( + nn.Linear(channels[4] * 4 * 4, channels[4]), + leaky_relu(), + nn.Linear(channels[4], 1) + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.final_conv(x) + x = x.view(x.shape[0], -1) + x = self.final_linear(x) + return x + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu() + ) + + self.downsample = nn.Sequential( + Blur(), + nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) + ) if downsample else None + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer('f', f) + + def forward(self, x): + f = self.f + f = f[None, None, :] * f [None, :, None] + return filter2d(x, f, normalized=True) + + +def leaky_relu(p=0.2): + return nn.LeakyReLU(p, inplace=True) + + +def exists(val): + return val is not None diff --git a/tokenizer/tokenizer_image/lpips.py b/tokenizer/tokenizer_image/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..73fb28f05ed9bb946dc6611670b921cf264223aa --- /dev/null +++ b/tokenizer/tokenizer_image/lpips.py @@ 
-0,0 +1,164 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import os, hashlib +import requests +from tqdm import tqdm + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + 
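For reference, the perceptual distance this LPIPS module computes is: unit-normalize each of the five VGG16 feature maps along the channel dimension, square the difference between the two inputs' features, weight it with the learned 1x1 `lin` convolutions, average spatially, and sum over layers. A toy, eval-mode sketch of that computation under assumed shapes (dropout omitted, random stand-in weights rather than the trained ones):

```python
# Illustrative re-implementation of the LPIPS forward pass on fake features;
# channel counts and weights below are placeholders, not the pretrained model.
import torch

def lpips_like_distance(feats_x, feats_y, lin_weights, eps=1e-10):
    val = 0.0
    for fx, fy, w in zip(feats_x, feats_y, lin_weights):
        fx = fx / (fx.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)      # normalize_tensor
        fy = fy / (fy.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)
        diff = (fx - fy) ** 2
        weighted = (diff * w.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)  # 1x1 conv, no bias
        val = val + weighted.mean(dim=[2, 3], keepdim=True)               # spatial_average
    return val  # shape (N, 1, 1, 1)

feats_x = [torch.randn(1, c, 8, 8) for c in (64, 128)]
feats_y = [torch.randn(1, c, 8, 8) for c in (64, 128)]
lin_weights = [torch.rand(c) for c in (64, 128)]
print(lpips_like_distance(feats_x, feats_y, lin_weights).shape)  # torch.Size([1, 1, 1, 1])
```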
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) \ No newline at end of file diff --git a/tokenizer/tokenizer_image/reconstruction_vq_ddp.py b/tokenizer/tokenizer_image/reconstruction_vq_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..93087ae818758b4676489ccb1543cf3168701eb3 --- /dev/null +++ b/tokenizer/tokenizer_image/reconstruction_vq_ddp.py @@ -0,0 +1,207 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision import transforms +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import argparse +import itertools + +from skimage.metrics import peak_signal_noise_ratio as psnr_loss +from skimage.metrics import structural_similarity as ssim_loss +from dataset.augmentation import center_crop_arr +from dataset.build import build_dataset +from tokenizer.tokenizer_image.vq_model import VQ_models + + + +def create_npz_from_sample_folder(sample_dir, num=50000): + """ + Builds a single .npz file from a folder of .png samples. 
+ """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + vq_model.to(device) + vq_model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + if "ema" in checkpoint: # ema + model_weight = checkpoint["ema"] + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + vq_model.load_state_dict(model_weight) + del checkpoint + + # Create folder to save samples: + folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}" + f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}") + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + if args.dataset == 'imagenet': + dataset = build_dataset(args, transform=transform) + num_fid_samples = 50000 + elif args.dataset == 'coco': + dataset = build_dataset(args, transform=transform) + num_fid_samples = 5000 + elif args.dataset == 'imagenet_code': + dataset = build_dataset(args) + num_fid_samples = 50000 + else: + raise Exception("please check dataset") + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=args.per_proc_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + + psnr_val_rgb = [] + ssim_val_rgb = [] + loader = tqdm(loader) if rank == 0 else loader + total = 0 + # for x, _ in loader: + for batch in loader: + x = batch['condition_imgs'].repeat(1,3,1,1) + # import pdb + # pdb.set_trace() + if args.image_size_eval != args.image_size: + rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') + else: + rgb_gts = x + rgb_gts = (rgb_gts.permute(0, 2, 
3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] + x = x.to(device, non_blocking=True) + with torch.no_grad(): + latent, _, [_, _, indices] = vq_model.encode(x.float()) + samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1] + if args.image_size_eval != args.image_size: + samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files (create_npz_from_sample_folder below expects them) + for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + # metric + rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] + psnr = psnr_loss(rgb_restored, rgb_gt) + ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) + psnr_val_rgb.append(psnr) + ssim_val_rgb.append(ssim) + + total += global_batch_size + + # ------------------------------------ + # Summary + # ------------------------------------ + # Make sure all processes have finished saving their samples + dist.barrier() + world_size = dist.get_world_size() + gather_psnr_val = [None for _ in range(world_size)] + gather_ssim_val = [None for _ in range(world_size)] + dist.all_gather_object(gather_psnr_val, psnr_val_rgb) + dist.all_gather_object(gather_ssim_val, ssim_val_rgb) + + if rank == 0: + gather_psnr_val = list(itertools.chain(*gather_psnr_val)) + gather_ssim_val = list(itertools.chain(*gather_ssim_val)) + psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) + ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) + + result_file = f"{sample_folder_dir}_results.txt" + print("writing results to {}".format(result_file)) + with open(result_file, 'w') as f: + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) + + create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) + print("Done.") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default=None) + parser.add_argument("--code-path", type=str, required=True) + parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco', 'imagenet_code'], default='imagenet') + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256) + parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) + parser.add_argument("--sample-dir", type=str, default="reconstructions") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--condition", type=str, choices=['canny', 'hed'], default='canny') + parser.add_argument("--get-condition-img", type=bool,
default=False) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/tokenizer_image/vq_demo.py b/tokenizer/tokenizer_image/vq_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..87ec958f631659e6f03953d5c9ba88875c55846c --- /dev/null +++ b/tokenizer/tokenizer_image/vq_demo.py @@ -0,0 +1,84 @@ +import torch +import torch.nn.functional as F + +import os +import argparse +import numpy as np +from PIL import Image + +from tokenizer.tokenizer_image.vq_model import VQ_models +from dataset.augmentation import center_crop_arr + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + model.to(device) + model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + if "ema" in checkpoint: # ema + model_weight = checkpoint["ema"] + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + model.load_state_dict(model_weight) + del checkpoint + + # output dir + os.makedirs(args.output_dir, exist_ok=True) + out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix)) + out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix)) + out_path = out_path.replace('.png', '_{}.png'.format(args.suffix)) + out_filename = out_path.split('/')[-1] + out_path = os.path.join(args.output_dir, out_filename) + + # load image + pil_image = Image.open(args.image_path).convert("RGB") + img = center_crop_arr(pil_image, args.image_size) + # # preprocess + # size_org = img.size + # img = img.resize((input_size, input_size)) + img = np.array(img) / 255. 
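The demo maps 8-bit pixels into the model's [-1, 1] range before encoding and maps the reconstruction back with `127.5 * y + 128.0` plus a clamp. A small sketch of that round trip with toy pixel values:

```python
# Toy illustration of the pre/post-processing value ranges used in this demo;
# the input array is made up, and the "reconstruction" is assumed to be exact.
import numpy as np

px = np.array([0, 128, 255], dtype=np.float32)
x = 2.0 * (px / 255.0) - 1.0                                 # model input in [-1, 1]
y = x                                                        # pretend decode(encode(x)) == x
back = np.clip(127.5 * y + 128.0, 0, 255).astype(np.uint8)   # back to uint8
print(back)  # [  0 128 255]
```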
+ x = 2.0 * img - 1.0 # x value is between [-1, 1] + x = torch.tensor(x) + x = x.unsqueeze(dim=0) + x = torch.einsum('nhwc->nchw', x) + x_input = x.float().to("cuda") + + # inference + with torch.no_grad(): + latent, _, [_, _, indices] = model.encode(x_input) + output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] + + # postprocess + output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0] + sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() + + # save + Image.fromarray(sample).save(out_path) + print("Reconstructed image is saved to {}".format(out_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image-path", type=str, default="assets/example.jpg") + parser.add_argument("--output-dir", type=str, default="output_vq_demo") + parser.add_argument("--suffix", type=str, default="tokenizer_image") + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/tokenizer_image/vq_loss.py b/tokenizer/tokenizer_image/vq_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fe386de7fea05dc1ab86cc5da0e3393edfd678c4 --- /dev/null +++ b/tokenizer/tokenizer_image/vq_loss.py @@ -0,0 +1,168 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tokenizer.tokenizer_image.lpips import LPIPS +from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator +from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator + + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.softplus(-logits_real)) + loss_fake = torch.mean(F.softplus(logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def non_saturating_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logits_real), logits_real)) + loss_fake = torch.mean(F.binary_cross_entropy_with_logits(torch.zeros_like(logits_fake), logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def hinge_gen_loss(logit_fake): + return -torch.mean(logit_fake) + + +def non_saturating_gen_loss(logit_fake): + return torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logit_fake), logit_fake)) + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +class VQLoss(nn.Module): + def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256, + disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False, + gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0, + codebook_weight=1.0, perceptual_weight=1.0, + ): + super().__init__() + # discriminator loss + assert disc_type in ["patchgan", "stylegan"] + assert disc_loss in ["hinge", "vanilla", "non-saturating"] + if disc_type == "patchgan": + self.discriminator = PatchGANDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + ndf=disc_dim, + ) + elif disc_type == "stylegan": + self.discriminator = StyleGANDiscriminator( + input_nc=disc_in_channels, + image_size=image_size, + ) + else: + raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.") + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + elif disc_loss == "non-saturating": + self.disc_loss = non_saturating_d_loss + else: + raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.") + self.discriminator_iter_start = disc_start + self.disc_weight = disc_weight + self.disc_adaptive_weight = disc_adaptive_weight + + assert gen_adv_loss in ["hinge", "non-saturating"] + # gen_adv_loss + if gen_adv_loss == "hinge": + self.gen_adv_loss = hinge_gen_loss + elif gen_adv_loss == "non-saturating": + self.gen_adv_loss = non_saturating_gen_loss + else: + raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.") + + # perceptual loss + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + # reconstruction loss + if reconstruction_loss == "l1": + self.rec_loss = F.l1_loss + elif reconstruction_loss == "l2": + self.rec_loss = F.mse_loss + else: + raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.") + self.rec_weight = reconstruction_weight + + # codebook loss + self.codebook_weight = codebook_weight + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + return d_weight.detach() + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, + logger=None, log_every=100): + # generator update + if optimizer_idx == 0: + # reconstruction loss + rec_loss 
= self.rec_loss(inputs.contiguous(), reconstructions.contiguous()) + + # perceptual loss + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + p_loss = torch.mean(p_loss) + + # discriminator loss + logits_fake = self.discriminator(reconstructions.contiguous()) + generator_adv_loss = self.gen_adv_loss(logits_fake) + + if self.disc_adaptive_weight: + null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss + disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss, last_layer=last_layer) + else: + disc_adaptive_weight = 1 + disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start) + + loss = self.rec_weight * rec_loss + \ + self.perceptual_weight * p_loss + \ + disc_adaptive_weight * disc_weight * generator_adv_loss + \ + codebook_loss[0] + codebook_loss[1] + codebook_loss[2] + + if global_step % log_every == 0: + rec_loss = self.rec_weight * rec_loss + p_loss = self.perceptual_weight * p_loss + generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss + logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, " + f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, " + f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, " + f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}") + return loss + + # discriminator update + if optimizer_idx == 1: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start) + d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake) + + if global_step % log_every == 0: + logits_real = logits_real.detach().mean() + logits_fake = logits_fake.detach().mean() + logger.info(f"(Discriminator) " + f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, " + f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}") + return d_adversarial_loss \ No newline at end of file diff --git a/tokenizer/tokenizer_image/vq_model.py b/tokenizer/tokenizer_image/vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c0185f0983e86e334c3feee34802199f77352a03 --- /dev/null +++ b/tokenizer/tokenizer_image/vq_model.py @@ -0,0 +1,425 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# maskgit: https://github.com/google-research/maskgit +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p) + self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, 
dropout=config.dropout_p) + + self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim, + config.commit_loss_beta, config.entropy_loss_ratio, + config.codebook_l2_norm, config.codebook_show_usage) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1) + + def encode(self, x): + #import pdb; pdb.set_trace() + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + + +class Encoder(nn.Module): + def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, + norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type)) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions-1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + + +class Decoder(nn.Module): + def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group", + dropout=0.0, resamp_with_conv=True, out_channels=3): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch*ch_mult[self.num_resolutions-1] + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.ModuleList() + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, 
norm_type=norm_type)) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type)) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = l2_norm + self.show_usage = show_usage + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + if self.l2_norm: + self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1) + if self.show_usage: + self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) + + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = torch.einsum('b c h w -> b h w c', z).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.l2_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(embedding**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding)) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + codebook_usage = 0 + + if self.show_usage and self.training: + cur_len = min_encoding_indices.shape[0] + self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone() + self.codebook_used[-cur_len:] = min_encoding_indices + codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e + + # compute loss for embedding + if self.training: + vq_loss = torch.mean((z_q - z.detach()) ** 2) + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + entropy_loss = self.entropy_loss_ratio * 
compute_entropy_loss(-d) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum('b h w c -> b c h w', z_q) + + return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type='group'): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, norm_type='group'): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 
'batch': + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +################################################################################# +# VQ Model Configs # +################################################################################# +def VQ_8(**kwargs): + return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs)) + +def VQ_16(**kwargs): + return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs)) + +VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8} \ No newline at end of file diff --git a/tokenizer/tokenizer_image/vq_model_hf.py b/tokenizer/tokenizer_image/vq_model_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..70277ed6f536d28f41319361aadef30baa957cab --- /dev/null +++ b/tokenizer/tokenizer_image/vq_model_hf.py @@ -0,0 +1,17 @@ +from huggingface_hub import PyTorchModelHubMixin + +from tokenizer.tokenizer_image.vq_model import ModelArgs, VQModel + +class VQModelHF(VQModel, PyTorchModelHubMixin, repo_url="https://github.com/FoundationVision/LlamaGen", license="mit", tags=["llamagen", "text-to-image"]): + pass + +################################################################################# +# VQ Model Configs # +################################################################################# +def VQ_8(**kwargs): + return VQModelHF(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs)) + +def VQ_16(**kwargs): + return VQModelHF(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs)) + +VQ_models_HF = {'VQ-16': VQ_16, 'VQ-8': VQ_8} diff --git a/tokenizer/tokenizer_image/vq_train.py b/tokenizer/tokenizer_image/vq_train.py new file mode 100644 index 0000000000000000000000000000000000000000..db545ddf139dcaff031a1acc9665908aa38b0447 --- /dev/null +++ b/tokenizer/tokenizer_image/vq_train.py @@ -0,0 +1,323 @@ +# Modified from: +# fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py +# nanoGPT: 
https://github.com/karpathy/nanoGPT/blob/master/model.py +import torch +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms + +import os +import time +import argparse +from glob import glob +from copy import deepcopy +# import sys +# sys.path.append('/data/vjuicefs_sz_cv_v2/11171709/ControlAR') +from utils.logger import create_logger +from utils.distributed import init_distributed_mode +from utils.ema import update_ema, requires_grad +from dataset.augmentation import random_crop_arr +from dataset.build import build_dataset +from tokenizer.tokenizer_image.vq_model import VQ_models +from tokenizer.tokenizer_image.vq_loss import VQLoss + +import warnings +warnings.filterwarnings('ignore') + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + """ + Trains a new model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + # Setup DDP: + init_distributed_mode(args) + assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + + # Setup an experiment folder: + if rank == 0: + os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) + experiment_index = len(glob(f"{args.results_dir}/*")) + model_string_name = args.vq_model.replace("/", "-") + experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder + checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + + time_record = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + cloud_results_dir = f"{args.cloud_save_path}/{time_record}" + cloud_checkpoint_dir = f"{cloud_results_dir}/{experiment_index:03d}-{model_string_name}/checkpoints" + os.makedirs(cloud_checkpoint_dir, exist_ok=True) + logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}") + + else: + logger = create_logger(None) + + # training args + logger.info(f"{args}") + + # training env + logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim, + commit_loss_beta=args.commit_loss_beta, + entropy_loss_ratio=args.entropy_loss_ratio, + dropout_p=args.dropout_p, + ) + logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}") + if args.ema: + ema = deepcopy(vq_model).to(device) # Create an EMA of the model for use after training + requires_grad(ema, False) + logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}") + vq_model = 
vq_model.to(device) + + vq_loss = VQLoss( + disc_start=args.disc_start, + disc_weight=args.disc_weight, + disc_type=args.disc_type, + disc_loss=args.disc_loss, + gen_adv_loss=args.gen_loss, + image_size=args.image_size, + perceptual_weight=args.perceptual_weight, + reconstruction_weight=args.reconstruction_weight, + reconstruction_loss=args.reconstruction_loss, + codebook_weight=args.codebook_weight, + ).to(device) + logger.info(f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}") + + # initialize a GradScaler. If enabled=False scaler is a no-op + scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16')) + scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16')) + # Setup optimizer + optimizer = torch.optim.Adam(vq_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) + optimizer_disc = torch.optim.Adam(vq_loss.discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: random_crop_arr(pil_image, args.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + if args.dataset == 'imagenet_code': + dataset = build_dataset(args) + else: + dataset = build_dataset(args, transform=transform) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=True, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=int(args.global_batch_size // dist.get_world_size()), + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True + ) + logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") + + + # Prepare models for training: + if args.vq_ckpt: + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + if args.ema: + ema.load_state_dict(checkpoint["ema"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + vq_loss.discriminator.load_state_dict(checkpoint["discriminator"]) + optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) + if not args.finetune: + train_steps = checkpoint["steps"] if "steps" in checkpoint else int(args.vq_ckpt.split('/')[-1].split('.')[0]) + start_epoch = int(train_steps / int(len(dataset) / args.global_batch_size)) + train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size)) + else: + train_steps = 0 + start_epoch = 0 + del checkpoint + logger.info(f"Resume training from checkpoint: {args.vq_ckpt}") + logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}") + else: + train_steps = 0 + start_epoch = 0 + if args.ema: + update_ema(ema, vq_model, decay=0) # Ensure EMA is initialized with synced weights + + if args.compile: + logger.info("compiling the model... 
(may take several minutes)") + vq_model = torch.compile(vq_model) # requires PyTorch 2.0 + + vq_model = DDP(vq_model.to(device), device_ids=[args.gpu]) + vq_model.train() + if args.ema: + ema.eval() # EMA model should always be in eval mode + vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu]) + vq_loss.train() + + ptdtype = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.mixed_precision] + + # Variables for monitoring/logging purposes: + log_steps = 0 + running_loss = 0 + start_time = time.time() + + logger.info(f"Training for {args.epochs} epochs...") + for epoch in range(start_epoch, args.epochs): + sampler.set_epoch(epoch) + logger.info(f"Beginning epoch {epoch}...") + for x, y in loader: + imgs = x.to(device, non_blocking=True) + + # generator training + optimizer.zero_grad() + with torch.cuda.amp.autocast(dtype=ptdtype): + recons_imgs, codebook_loss = vq_model(imgs) + loss_gen = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=0, global_step=train_steps+1, + last_layer=vq_model.module.decoder.last_layer, + logger=logger, log_every=args.log_every) + scaler.scale(loss_gen).backward() + if args.max_grad_norm != 0.0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(vq_model.parameters(), args.max_grad_norm) + scaler.step(optimizer) + scaler.update() + if args.ema: + update_ema(ema, vq_model.module._orig_mod if args.compile else vq_model.module) + + # discriminator training + optimizer_disc.zero_grad() + with torch.cuda.amp.autocast(dtype=ptdtype): + loss_disc = vq_loss(codebook_loss, imgs, recons_imgs, optimizer_idx=1, global_step=train_steps+1, + logger=logger, log_every=args.log_every) + scaler_disc.scale(loss_disc).backward() + if args.max_grad_norm != 0.0: + scaler_disc.unscale_(optimizer_disc) + torch.nn.utils.clip_grad_norm_(vq_loss.module.discriminator.parameters(), args.max_grad_norm) + scaler_disc.step(optimizer_disc) + scaler_disc.update() + + # # Log loss values: + running_loss += loss_gen.item() + loss_disc.item() + + log_steps += 1 + train_steps += 1 + if train_steps % args.log_every == 0: + # Measure training speed: + torch.cuda.synchronize() + end_time = time.time() + steps_per_sec = log_steps / (end_time - start_time) + # Reduce loss history over all processes: + avg_loss = torch.tensor(running_loss / log_steps, device=device) + dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + avg_loss = avg_loss.item() / dist.get_world_size() + logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") + # Reset monitoring variables: + running_loss = 0 + log_steps = 0 + start_time = time.time() + + # Save checkpoint: + if train_steps % args.ckpt_every == 0 and train_steps > 0: + if rank == 0: + if args.compile: + model_weight = vq_model.module._orig_mod.state_dict() + else: + model_weight = vq_model.module.state_dict() + checkpoint = { + "model": model_weight, + "optimizer": optimizer.state_dict(), + "discriminator": vq_loss.module.discriminator.state_dict(), + "optimizer_disc": optimizer_disc.state_dict(), + "steps": train_steps, + "args": args + } + if args.ema: + checkpoint["ema"] = ema.state_dict() + if not args.no_local_save: + checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" + torch.save(checkpoint, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + + cloud_checkpoint_path = f"{cloud_checkpoint_dir}/{train_steps:07d}.pt" + torch.save(checkpoint, cloud_checkpoint_path) + logger.info(f"Saved checkpoint in cloud to {cloud_checkpoint_path}") + 
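# synchronize all ranks here so training does not resume until rank 0 has finished writing the checkpoint + 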
dist.barrier() + + vq_model.eval() # important! This disables randomized embedding dropout + # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... + + logger.info("Done!") + dist.destroy_process_group() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default=None) + parser.add_argument("--code-path", type=str, default=None) + parser.add_argument("--data-face-path", type=str, default=None, help="face datasets to improve vq model") + parser.add_argument("--cloud-save-path", type=str, required=True, help='please specify a cloud disk path, if not, local path') + parser.add_argument("--no-local-save", action='store_true', help='no save checkpoints to local path for limited disk volume') + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") + parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for resume training") + parser.add_argument("--finetune", action='store_true', help="finetune a pre-trained vq model") + parser.add_argument("--ema", action='store_true', help="whether using ema training") + parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") + parser.add_argument("--codebook-l2-norm", action='store_true', default=True, help="l2 norm codebook") + parser.add_argument("--codebook-weight", type=float, default=1.0, help="codebook loss weight for vector quantization") + parser.add_argument("--entropy-loss-ratio", type=float, default=0.0, help="entropy loss ratio in codebook loss") + parser.add_argument("--commit-loss-beta", type=float, default=0.25, help="commit loss beta in codebook loss") + parser.add_argument("--reconstruction-weight", type=float, default=1.0, help="reconstruction loss weight of image pixel") + parser.add_argument("--reconstruction-loss", type=str, default='l2', help="reconstruction loss type of image pixel") + parser.add_argument("--perceptual-weight", type=float, default=1.0, help="perceptual loss weight of LPIPS") + parser.add_argument("--disc-weight", type=float, default=0.5, help="discriminator loss weight for gan training") + parser.add_argument("--disc-start", type=int, default=20000, help="iteration to start discriminator training and loss") + parser.add_argument("--disc-type", type=str, choices=['patchgan', 'stylegan'], default='patchgan', help="discriminator type") + parser.add_argument("--disc-loss", type=str, choices=['hinge', 'vanilla', 'non-saturating'], default='hinge', help="discriminator loss") + parser.add_argument("--gen-loss", type=str, choices=['hinge', 'non-saturating'], default='hinge', help="generator loss for gan training") + parser.add_argument("--compile", action='store_true', default=False) + parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout_p") + parser.add_argument("--results-dir", type=str, default="results_tokenizer_image") + parser.add_argument("--dataset", type=str, default='imagenet') + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--epochs", type=int, default=40) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--weight-decay", type=float, default=5e-2, help="Weight decay to use.") + parser.add_argument("--beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + 
parser.add_argument("--beta2", type=float, default=0.95, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--global-batch-size", type=int, default=64) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=16) + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--ckpt-every", type=int, default=5000) + parser.add_argument("--gradient-accumulation-steps", type=int, default=1) + parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) + parser.add_argument("--condition", type=str, default='hed') + parser.add_argument("--get-condition-img", type=bool, default=False) + args = parser.parse_args() + main(args) diff --git a/tokenizer/vae/README.md b/tokenizer/vae/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2f57bd73ff02b35fc8509d3fed9d74d01b77bed1 --- /dev/null +++ b/tokenizer/vae/README.md @@ -0,0 +1,14 @@ +## VAE Models from Stable Diffusion + +### install +``` +pip install diffusers +pip install accelerate +``` + +### demo +``` +cd ${THIS_REPO_ROOT} +python3 tokenizer/vae/sd_vae_demo.py +``` + diff --git a/tokenizer/vae/reconstruction_vae_ddp.py b/tokenizer/vae/reconstruction_vae_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..4d674ddf4455edce800d036e9f0e97c69f72c210 --- /dev/null +++ b/tokenizer/vae/reconstruction_vae_ddp.py @@ -0,0 +1,210 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +from tqdm import tqdm +import os +import itertools +from PIL import Image +import numpy as np +import argparse +import random + +from skimage.metrics import peak_signal_noise_ratio as psnr_loss +from skimage.metrics import structural_similarity as ssim_loss +from diffusers.models import AutoencoderKL + + +class SingleFolderDataset(Dataset): + def __init__(self, directory, transform=None): + super().__init__() + self.directory = directory + self.transform = transform + self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) + if os.path.isfile(os.path.join(directory, file_name))] + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + image = Image.open(image_path).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(0) + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + + random.shuffle(samples) # This is very important for IS(Inception Score) !!! 
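+    # samples were written to disk in dataset order (class-sorted for ImageFolder), so shuffling avoids class-correlated splits when Inception Score is computed on this .npz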
+ samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # load vae + vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device) + + # Create folder to save samples: + folder_name = f"stabilityai-{args.vae}-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + if args.dataset == 'imagenet': + dataset = ImageFolder(args.data_path, transform=transform) + num_fid_samples = 50000 + elif args.dataset == 'coco': + dataset = SingleFolderDataset(args.data_path, transform=transform) + num_fid_samples = 5000 + else: + raise Exception("please check dataset") + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=args.per_proc_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + + psnr_val_rgb = [] + ssim_val_rgb = [] + loader = tqdm(loader) if rank == 0 else loader + total = 0 + for x, _ in loader: + rgb_gts = x + rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] + x = x.to(device) + with torch.no_grad(): + # Map input images to latent space + normalize latents: + latent = vae.encode(x).latent_dist.sample().mul_(0.18215) + # reconstruct: + samples = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 
3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + # metric + rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] + psnr = psnr_loss(rgb_restored, rgb_gt) + ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) + psnr_val_rgb.append(psnr) + ssim_val_rgb.append(ssim) + total += global_batch_size + + # ------------------------------------ + # Summary + # ------------------------------------ + # Make sure all processes have finished saving their samples + dist.barrier() + world_size = dist.get_world_size() + gather_psnr_val = [None for _ in range(world_size)] + gather_ssim_val = [None for _ in range(world_size)] + dist.all_gather_object(gather_psnr_val, psnr_val_rgb) + dist.all_gather_object(gather_ssim_val, ssim_val_rgb) + + if rank == 0: + gather_psnr_val = list(itertools.chain(*gather_psnr_val)) + gather_ssim_val = list(itertools.chain(*gather_ssim_val)) + psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) + ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) + + result_file = f"{sample_folder_dir}_results.txt" + print("writing results to {}".format(result_file)) + with open(result_file, 'w') as f: + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) + + create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) + print("Done.") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') + parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--sample-dir", type=str, default="reconstructions") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=4) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/vae/sd_vae_demo.py b/tokenizer/vae/sd_vae_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5859d1ebf8a5ac43e7437fabfa8a380000c696 --- /dev/null +++ b/tokenizer/vae/sd_vae_demo.py @@ -0,0 +1,57 @@ +import argparse +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image +from diffusers.models import AutoencoderKL + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device) + + # load image + img_path = args.image_path + out_path = args.image_path.replace('.jpg', '_vae.jpg').replace('.jpeg', '_vae.jpeg').replace('.png', '_vae.png') + input_size = args.image_size + img = Image.open(img_path).convert("RGB") + + # preprocess + size_org = img.size + img = img.resize((input_size, input_size)) + img = np.array(img) / 255. 
+ x = 2.0 * img - 1.0 # x value is between [-1, 1] + x = torch.tensor(x) + x = x.unsqueeze(dim=0) + x = torch.einsum('nhwc->nchw', x) + x_input = x.float().to("cuda") + + # inference + with torch.no_grad(): + # Map input images to latent space + normalize latents: + latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215) + # reconstruct: + output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] + + # postprocess + output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] + sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() + + # save + Image.fromarray(sample).save(out_path) + print("Reconstructed image is saved to {}".format(out_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image-path", type=str, default="assets/example.jpg") + parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse") + parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/validation/val_ddp.py b/tokenizer/validation/val_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e77eebbaf915f8d10484504b148c99170aa022 --- /dev/null +++ b/tokenizer/validation/val_ddp.py @@ -0,0 +1,165 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import argparse +import random + + +class SingleFolderDataset(Dataset): + def __init__(self, directory, transform=None): + super().__init__() + self.directory = directory + self.transform = transform + self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) + if os.path.isfile(os.path.join(directory, file_name))] + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + image = Image.open(image_path).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(0) + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + + random.shuffle(samples) # This is very important for IS(Inception Score) !!! + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup env + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # Create folder to save samples: + folder_name = f"val_{args.dataset}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + if args.dataset == 'imagenet': + dataset = ImageFolder(args.data_path, transform=transform) + num_fid_samples = 50000 + elif args.dataset == 'coco': + dataset = SingleFolderDataset(args.data_path, transform=transform) + num_fid_samples = 5000 + else: + raise Exception("please check dataset") + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=args.per_proc_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + + loader = tqdm(loader) if rank == 0 else loader + total = 0 + for x, _ in loader: + samples = torch.clamp(127.5 * x + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + + total += global_batch_size + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--sample-dir", type=str, default="reconstructions") 
+ parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=4) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/vqgan/README.md b/tokenizer/vqgan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9fc148622fa94347ea479067b35a2b1ee3375942 --- /dev/null +++ b/tokenizer/vqgan/README.md @@ -0,0 +1,21 @@ +## Pretrained VQVAE Models + +### install +``` +pip install omegaconf +pip install einops +``` +* download all needed models from https://github.com/CompVis/taming-transformers and put in pretrained_models/ +* pip install pytorch_lightning +* python3 tools/convert_pytorch_lightning_to_torch.py +* pip uninstall pytorch_lightning + +### demo +``` +cd ${THIS_REPO_ROOT} +python3 tokenizer/vqgan/taming_vqgan_demo.py +``` + +### acknowledge +Codes in this folder are modified from from https://github.com/CompVis/taming-transformers + diff --git a/tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml b/tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc695bee7945ead1ef5b617829b8f3b7a739eb01 --- /dev/null +++ b/tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml @@ -0,0 +1,32 @@ +model: + base_learning_rate: 4.5e-06 + target: taming.models.vqgan.VQModel + params: + embed_dim: 256 + n_embed: 1024 + ddconfig: + double_z: false + z_channels: 256 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 0 + disc_weight: 0.8 + codebook_weight: 1.0 + diff --git a/tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml b/tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ba40acebbaa88d3416a347477ccc3ac5ed75d56 --- /dev/null +++ b/tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml @@ -0,0 +1,34 @@ +model: + base_learning_rate: 4.5e-06 + target: taming.models.vqgan.VQModel + params: + embed_dim: 256 + n_embed: 16384 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 256 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 0 + disc_weight: 0.75 + disc_num_layers: 2 + codebook_weight: 1.0 + diff --git a/tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml b/tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml new file mode 100644 index 0000000000000000000000000000000000000000..74fad8a3fe69db7dcd740ef0906eff543013d058 --- /dev/null +++ b/tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml @@ -0,0 +1,20 @@ +model: + params: + embed_dim: 4 + n_embed: 16384 + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 \ No newline at end of file diff --git a/tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml b/tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..229adcf6b5ddfcdda8179b78e83985fc28eebf67 --- /dev/null +++ b/tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml @@ -0,0 +1,20 @@ +model: + params: + embed_dim: 4 + n_embed: 256 + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 \ No newline at end of file diff --git a/tokenizer/vqgan/layer.py b/tokenizer/vqgan/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..66ae99a48b62d469ad17f52be2b6902e83a0573b --- /dev/null +++ b/tokenizer/vqgan/layer.py @@ -0,0 +1,372 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + 
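# single-head self-attention over the h*w spatial positions; the 1x1 convs below act as per-position linear projections for q, k, v + 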
self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + 
h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + diff --git a/tokenizer/vqgan/model.py b/tokenizer/vqgan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb79bec88316cbf8e23e0860e8d22b63934a172 --- /dev/null +++ b/tokenizer/vqgan/model.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tokenizer.vqgan.layer import Encoder, Decoder +from tokenizer.vqgan.quantize import VectorQuantizer2 as VectorQuantizer + + +VQGAN_FROM_TAMING = { + 'vqgan_imagenet_f16_1024': ( + 'tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml', + 'pretrained_models/vqgan_imagenet_f16_1024/ckpts/last.pth'), + 'vqgan_imagenet_f16_16384': ( + 
'tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml', + 'pretrained_models/vqgan_imagenet_f16_16384/ckpts/last.pth'), + 'vqgan_openimage_f8_256': ( + 'tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml', + 'pretrained_models/vq-f8-n256/model.pth'), + 'vqgan_openimage_f8_16384': ( + 'tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml', + 'pretrained_models/vq-f8/model.pth'), +} + +class VQModel(nn.Module): + def __init__(self, + ddconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + **kwargs, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list(), logging=True): + model_weight = torch.load(path, map_location="cpu")["state_dict"] + keys = list(model_weight.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del model_weight[k] + missing, unexpected = self.load_state_dict(model_weight, strict=False) + if logging: + print(f"Restored from {path}") + print(f"Missing Keys in State Dict: {missing}") + print(f"Unexpected Keys in State Dict: {unexpected}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff diff --git a/tokenizer/vqgan/quantize.py b/tokenizer/vqgan/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..7239b72c9c5d062b9a44fb8f4e8def871ff5ef7b --- /dev/null +++ b/tokenizer/vqgan/quantize.py @@ -0,0 +1,229 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. 
VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
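+    # legacy=True keeps the original weighting (beta on the codebook term ||z_q - sg[z]||^2);
+    # legacy=False applies beta to the commitment term ||sg[z_q] - z||^2, matching the VQ-VAE paper.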
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape, channel_first=True): + # shape = (batch, channel, height, width) if 
channel_first else (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + + return z_q \ No newline at end of file diff --git a/tokenizer/vqgan/reconstruction_vqgan_ddp.py b/tokenizer/vqgan/reconstruction_vqgan_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..2af696cc7f1b49e0d9cf00080b2f381e0923431e --- /dev/null +++ b/tokenizer/vqgan/reconstruction_vqgan_ddp.py @@ -0,0 +1,215 @@ +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import itertools +import argparse +import random + +from skimage.metrics import peak_signal_noise_ratio as psnr_loss +from skimage.metrics import structural_similarity as ssim_loss +from omegaconf import OmegaConf +from tokenizer.vqgan.model import VQModel +from tokenizer.vqgan.model import VQGAN_FROM_TAMING + + +class SingleFolderDataset(Dataset): + def __init__(self, directory, transform=None): + super().__init__() + self.directory = directory + self.transform = transform + self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) + if os.path.isfile(os.path.join(directory, file_name))] + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image_path = self.image_paths[idx] + image = Image.open(image_path).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(0) + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + + random.shuffle(samples) # This is very important for IS(Inception Score) !!! + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def main(args): + # Setup PyTorch: + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # create and load vqgan + cfg, ckpt = VQGAN_FROM_TAMING[args.vqgan] + config = OmegaConf.load(cfg) + vq_model = VQModel(**config.model.get("params", dict())).to(device) + vq_model.init_from_ckpt(ckpt, logging=False) + vq_model.eval() + + # Create folder to save samples: + folder_name = f"{args.vqgan}-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + if args.dataset == 'imagenet': + dataset = ImageFolder(args.data_path, transform=transform) + num_fid_samples = 50000 + elif args.dataset == 'coco': + dataset = SingleFolderDataset(args.data_path, transform=transform) + num_fid_samples = 5000 + else: + raise Exception("please check dataset") + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=args.per_proc_batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False + ) + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + + psnr_val_rgb = [] + ssim_val_rgb = [] + loader = tqdm(loader) if rank == 0 else loader + total = 0 + for x, _ in loader: + rgb_gts = x + rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] + x = x.to(device) + with torch.no_grad(): + latent, _, [_, _, indices] = vq_model.encode(x) + samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1] + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + # metric + 
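+            # PSNR/SSIM are computed per image on the [0, 1]-scaled reconstruction
+            # against the ground-truth crop; the per-rank lists are merged with
+            # all_gather_object and averaged on rank 0 in the summary below.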
rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] + psnr = psnr_loss(rgb_restored, rgb_gt) + ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) + psnr_val_rgb.append(psnr) + ssim_val_rgb.append(ssim) + total += global_batch_size + + # ------------------------------------ + # Summary + # ------------------------------------ + # Make sure all processes have finished saving their samples + dist.barrier() + world_size = dist.get_world_size() + gather_psnr_val = [None for _ in range(world_size)] + gather_ssim_val = [None for _ in range(world_size)] + dist.all_gather_object(gather_psnr_val, psnr_val_rgb) + dist.all_gather_object(gather_ssim_val, ssim_val_rgb) + + if rank == 0: + gather_psnr_val = list(itertools.chain(*gather_psnr_val)) + gather_ssim_val = list(itertools.chain(*gather_ssim_val)) + psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) + ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) + + result_file = f"{sample_folder_dir}_results.txt" + print("writing results to {}".format(result_file)) + with open(result_file, 'w') as f: + print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) + + create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) + print("Done.") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') + parser.add_argument("--vqgan", type=str, choices=list(VQGAN_FROM_TAMING.keys()), default="vqgan_imagenet_f16_16384") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--sample-dir", type=str, default="reconstructions") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=4) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tokenizer/vqgan/taming_vqgan_demo.py b/tokenizer/vqgan/taming_vqgan_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..176b739df21e147a93e502dbb2d070edf518b876 --- /dev/null +++ b/tokenizer/vqgan/taming_vqgan_demo.py @@ -0,0 +1,68 @@ +import argparse +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +from tokenizer.vqgan.model import VQModel +from tokenizer.vqgan.model import VQGAN_FROM_TAMING + +# before running demo, make sure to: +# (1) download all needed models from https://github.com/CompVis/taming-transformers and put in pretrained_models/ +# (2) pip install pytorch_lightning +# (3) python3 tools/convert_pytorch_lightning_to_torch.py +# (4) pip uninstall pytorch_lightning + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + cfg, ckpt = VQGAN_FROM_TAMING[args.vqgan] + config = OmegaConf.load(cfg) + model = VQModel(**config.model.get("params", dict())) + model.init_from_ckpt(ckpt) + model.to(device) + model.eval() + + # load image + img_path = args.image_path + out_path = args.image_path.replace('.jpg', '_vqgan.jpg').replace('.jpeg', '_vqgan.jpeg').replace('.png', '_vqgan.png') + input_size = args.image_size + 
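+    # Preprocessing below: resize to a square input_size, scale pixels to [-1, 1],
+    # add a batch dimension, and permute NHWC -> NCHW before moving the tensor to the GPU.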
img = Image.open(img_path).convert("RGB") + + # preprocess + size_org = img.size + img = img.resize((input_size, input_size)) + img = np.array(img) / 255. + x = 2.0 * img - 1.0 # x value is between [-1, 1] + x = torch.tensor(x) + x = x.unsqueeze(dim=0) + x = torch.einsum('nhwc->nchw', x) + x_input = x.float().to("cuda") + + # inference + with torch.no_grad(): + latent, _, [_, _, indices] = model.encode(x_input) + output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] + + # postprocess + output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] + sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() + + # save + Image.fromarray(sample).save(out_path) + print("Reconstructed image is saved to {}".format(out_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image-path", type=str, default="assets/example.jpg") + parser.add_argument("--vqgan", type=str, choices=list(VQGAN_FROM_TAMING.keys()), default="vqgan_openimage_f8_16384") + parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..313e7b94e595f1d8e94b1ff6afa86678791440ce --- /dev/null +++ b/utils/data.py @@ -0,0 +1,22 @@ +import numpy as np +from PIL import Image + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) \ No newline at end of file diff --git a/utils/deepspeed.py b/utils/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..15cc757853126673096238676494b0e8a95347b8 --- /dev/null +++ b/utils/deepspeed.py @@ -0,0 +1,87 @@ +def create_deepspeed_config(args): + ds_config = { + "steps_per_print": 1000, + "train_batch_size": args.global_batch_size, + "gradient_accumulation_steps": args.gradient_accumulation_steps, + # "train_micro_batch_size_per_gpu": args.batch_size, # determined by (train_batch_size, gradient_accumulation_steps) + "optimizer": { + "type": "Adam", + "adam_w_mode": True, + "params": { + "lr": args.lr, + "weight_decay": args.weight_decay, + "bias_correction": True, + "betas": [ + args.beta1, + args.beta2 + ], + } + }, + "fp16": { + "enabled": args.mixed_precision == 'fp16', + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": args.mixed_precision == 'bf16', + }, + # "flops_profiler": { + # "enabled": True, + # "profile_step": -1, + # "module_depth": -1, + # "top_modules": 1, + # "detailed": True, + # }, + "zero_allow_untested_optimizer": True + } + + if args.clip_grad is not None: + ds_config.update({'gradient_clipping': args.clip_grad}) 
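+    # ZeRO stages configured below: stage 0 disables partitioning (plain data
+    # parallelism), stage 1 shards optimizer states, stage 2 additionally shards
+    # gradients, and stage 3 also shards parameters (with the stage3_* prefetch
+    # and persistence thresholds controlling gather/release behaviour).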
+ + if args.zero_stage == 0: + ds_config.update({"zero_optimization": + { + "stage": args.zero_stage, + "contiguous_gradients": True, + "overlap_comm": True, + } + }) + elif args.zero_stage == 1: + ds_config.update({"zero_optimization": + { + "stage": args.zero_stage, + "contiguous_gradients": True, + "overlap_comm": True, + "reduce_bucket_size": 5e8, + } + }) + elif args.zero_stage == 2: + ds_config.update({"zero_optimization": + { + "stage": args.zero_stage, + "contiguous_gradients": True, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8, + } + }) + elif args.zero_stage == 3: + ds_config.update({"zero_optimization": + { + "stage": args.zero_stage, + "contiguous_gradients": True, + "overlap_comm": True, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": 5e8, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": True + } + }) + + return ds_config diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..9e91d863bb489d178a14ce2721782bae038714aa --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,58 @@ +import os +import torch +import subprocess + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_url = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + elif 'SLURM_PROCID' in os.environ: + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + args.dist_url = 'env://' + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) diff --git a/utils/drop_path.py b/utils/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..991a323e33304a2fe71c9757c13ace5c89fd8a5a --- /dev/null +++ b/utils/drop_path.py @@ -0,0 +1,36 @@ +# from timm.models.layers import DropPath +import torch + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied 
in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(torch.nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' \ No newline at end of file diff --git a/utils/ema.py b/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd4d29f463320ce6c35437402cf913743f9e427 --- /dev/null +++ b/utils/ema.py @@ -0,0 +1,22 @@ +import torch +from collections import OrderedDict + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..87dbef04973484e1f6fd29d1fd9d0c7795c25e8e --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,19 @@ +import logging +import torch.distributed as dist + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger \ No newline at end of file diff --git a/utils/video.py b/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..4142c72fcb0eff187bd4f8d5880b06585237e1ca --- /dev/null +++ b/utils/video.py @@ -0,0 +1,116 @@ +import math +import numpy as np +import skvideo.io +from PIL import Image + +# Shifts src_tf dim to dest dim +# i.e. 
shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def save_video_grid(video, fname, nrow=None, fps=5): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)}) + + +def save_gif_grid(video, file_name, nrow=None, fps=5): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + images = [] + for frame in video_grid: + images.append(Image.fromarray(frame)) + + # Save the first image and append the rest of the images as frames in the GIF + images[0].save(file_name, save_all=True, append_images=images[1:], optimize=False, duration=int(1000/fps), loop=0) + + # The 'duration' parameter defines the display time for each frame in milliseconds + # The 'loop' parameter defines the number of loops the GIF should make (0 for infinite loop)
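+
+# Minimal usage sketch for the grid writers above (illustrative only; assumes a
+# float tensor `videos` of shape (b, c, t, h, w) with values in [0, 1], plus
+# scikit-video/ffmpeg available for the .mp4 path):
+#
+#   import torch
+#   videos = torch.rand(4, 3, 16, 64, 64)   # 4 clips, 16 RGB frames of 64x64
+#   save_video_grid(videos, "samples.mp4", fps=5)
+#   save_gif_grid(videos, "samples.gif", fps=5)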