hysts (HF staff) committed
Commit 639c25d
1 parent: 278b1b5
Files changed (11):
  1. .gitignore +162 -0
  2. .gitmodules +3 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. Dockerfile +53 -0
  6. README.md +3 -3
  7. app.py +105 -0
  8. model.py +515 -0
  9. requirements.txt +13 -0
  10. style.css +3 -0
  11. unidiffuser +1 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
+ models/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "unidiffuser"]
+ 	path = unidiffuser
+ 	url = https://github.com/thu-ml/unidiffuser
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
+ exclude: patch
+ repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.12.0
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.991
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+     additional_dependencies: ['types-python-slugify']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
Dockerfile ADDED
@@ -0,0 +1,53 @@
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ ENV DEBIAN_FRONTEND=noninteractive
+ RUN apt-get update && \
+     apt-get upgrade -y && \
+     apt-get install -y --no-install-recommends \
+         git \
+         git-lfs \
+         wget \
+         curl \
+         # python build dependencies \
+         build-essential \
+         libssl-dev \
+         zlib1g-dev \
+         libbz2-dev \
+         libreadline-dev \
+         libsqlite3-dev \
+         libncursesw5-dev \
+         xz-utils \
+         tk-dev \
+         libxml2-dev \
+         libxmlsec1-dev \
+         libffi-dev \
+         liblzma-dev && \
+     apt-get clean && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:${PATH}
+ WORKDIR ${HOME}/app
+
+ RUN curl https://pyenv.run | bash
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ ARG PYTHON_VERSION=3.10.10
+ RUN pyenv install ${PYTHON_VERSION} && \
+     pyenv global ${PYTHON_VERSION} && \
+     pyenv rehash && \
+     pip install --no-cache-dir -U pip setuptools wheel
+
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
+
+ COPY --chown=1000 . ${HOME}/app
+ ENV PYTHONPATH=${HOME}/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+ CMD ["python", "app.py"]
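The image pins torch 1.13.1 / torchvision 0.14.1 against CUDA 11.7 and hands control to app.py via CMD. A minimal sanity check one might run inside the built container (a sketch, not part of this commit) is:

    # Sketch: confirm the interpreter sees the torch build the Dockerfile installs.
    # Assumes it is executed inside the image built from the Dockerfile above.
    import torch

    print(torch.__version__)          # expected 1.13.1, as pinned above
    print(torch.cuda.is_available())  # True only when the Space has a GPU attached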
README.md CHANGED
@@ -1,12 +1,12 @@
  ---
- title: Unidiffuser
+ title: UniDiffuser
  emoji: 😻
  colorFrom: gray
  colorTo: green
- sdk: gradio
- sdk_version: 3.20.1
+ sdk: docker
  app_file: app.py
  pinned: false
+ license: other
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+
+ import gradio as gr
+
+ from model import Model
+
+ DESCRIPTION = '# [UniDiffuser](https://github.com/thu-ml/unidiffuser)'
+
+ SPACE_ID = os.getenv('SPACE_ID')
+ if SPACE_ID is not None:
+     DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
+
+ model = Model()
+
+
+ def create_demo(mode_name: str) -> gr.Blocks:
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 mode = gr.Dropdown(label='Mode',
+                                    choices=[
+                                        't2i',
+                                        'i2t',
+                                        'joint',
+                                        'i',
+                                        't',
+                                        'i2t2i',
+                                        't2i2t',
+                                    ],
+                                    value=mode_name,
+                                    visible=False)
+                 prompt = gr.Text(label='Prompt',
+                                  max_lines=1,
+                                  visible=mode_name in ['t2i', 't2i2t'])
+                 image = gr.Image(label='Input image',
+                                  type='filepath',
+                                  visible=mode_name in ['i2t', 'i2t2i'])
+                 run_button = gr.Button('Run')
+                 with gr.Accordion('Advanced options', open=False):
+                     seed = gr.Slider(
+                         label='Seed',
+                         minimum=-1,
+                         maximum=1000000,
+                         step=1,
+                         value=-1,
+                         info=
+                         'If set to -1, a different seed will be used each time.'
+                     )
+                     num_steps = gr.Slider(label='Steps',
+                                           minimum=1,
+                                           maximum=100,
+                                           value=50,
+                                           step=1)
+                     guidance_scale = gr.Slider(label='Guidance Scale',
+                                                minimum=0.1,
+                                                maximum=30.0,
+                                                value=7.0,
+                                                step=0.1)
+             with gr.Column():
+                 result_image = gr.Image(label='Generated image',
+                                         visible=mode_name
+                                         in ['t2i', 'i', 'joint', 'i2t2i'])
+                 result_text = gr.Text(label='Generated text',
+                                       visible=mode_name
+                                       in ['i2t', 't', 'joint', 't2i2t'])
+         inputs = [
+             mode,
+             prompt,
+             image,
+             seed,
+             num_steps,
+             guidance_scale,
+         ]
+         outputs = [
+             result_image,
+             result_text,
+         ]
+
+         prompt.submit(fn=model.run, inputs=inputs, outputs=outputs)
+         run_button.click(fn=model.run, inputs=inputs, outputs=outputs)
+     return demo
+
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Tabs():
+         with gr.TabItem('text2image'):
+             create_demo('t2i')
+         with gr.TabItem('image2text'):
+             create_demo('i2t')
+         with gr.TabItem('image variation'):
+             create_demo('i2t2i')
+         with gr.TabItem('joint generation'):
+             create_demo('joint')
+         with gr.TabItem('image generation'):
+             create_demo('i')
+         with gr.TabItem('text generation'):
+             create_demo('t')
+         with gr.TabItem('text variation'):
+             create_demo('t2i2t')
+     demo.queue(api_open=False).launch()
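Every tab wires the same six inputs (mode, prompt, image, seed, steps, guidance scale) into Model.run. A sketch of the equivalent call without the Gradio UI (the prompt below is only an example value; the downloaded checkpoints and a GPU are assumed):

    # Sketch: invoke the model directly, mirroring the argument order app.py uses.
    from model import Model

    model = Model()
    image, text = model.run(
        mode='t2i',               # one of: t2i, i2t, joint, i, t, i2t2i, t2i2t
        prompt='an elephant under the sea',  # example prompt, not from the commit
        image_path='',            # only read by the image-conditioned modes
        seed=-1,                  # -1 picks a random seed, as in the UI
        num_steps=50,
        guidance_scale=7.0,
    )
    image.save('sample.png')      # t2i returns (PIL image, empty caption string)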
model.py ADDED
@@ -0,0 +1,515 @@
+ from __future__ import annotations
+
+ import pathlib
+ import random
+ import sys
+ from typing import Callable
+
+ import clip
+ import einops
+ import numpy as np
+ import PIL.Image
+ import torch
+ from huggingface_hub import snapshot_download
+
+ repo_dir = pathlib.Path(__file__).parent
+ submodule_dir = repo_dir / 'unidiffuser'
+ sys.path.append(submodule_dir.as_posix())
+
+ import utils
+ from configs.sample_unidiffuser_v1 import get_config
+ from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
+ from libs.autoencoder import FrozenAutoencoderKL
+ from libs.autoencoder import get_model as get_autoencoder
+ from libs.caption_decoder import CaptionDecoder
+ from libs.clip import FrozenCLIPEmbedder
+
+ model_dir = repo_dir / 'models'
+ if not model_dir.exists():
+     snapshot_download('thu-ml/unidiffuser-v1',
+                       repo_type='model',
+                       local_dir=model_dir)
+
+
+ def stable_diffusion_beta_schedule(linear_start=0.00085,
+                                    linear_end=0.0120,
+                                    n_timestep=1000):
+     _betas = (torch.linspace(linear_start**0.5,
+                              linear_end**0.5,
+                              n_timestep,
+                              dtype=torch.float64)**2)
+     return _betas.numpy()
+
+
+ class Model:
+     def __init__(self):
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.config = get_config()
+
+         self.nnet = self.load_model()
+         self.caption_decoder = CaptionDecoder(device=self.device,
+                                               **self.config.caption_decoder)
+         self.clip_text_model = self.load_clip_text_model()
+         self.autoencoder = self.load_autoencoder()
+
+         self.clip_img_model, self.clip_img_model_preprocess = clip.load(
+             'ViT-B/32', device=self.device, jit=False)
+         self.empty_context = self.clip_text_model.encode([''])[0]
+
+         self.betas = stable_diffusion_beta_schedule()
+         self.N = len(self.betas)
+
+     @property
+     def use_caption_decoder(self) -> bool:
+         return (self.config.text_dim < self.config.clip_text_dim
+                 or self.config.mode != 't2i')
+
+     def load_model(self,
+                    model_path: str = 'models/uvit_v1.pth') -> torch.nn.Module:
+         model = utils.get_nnet(**self.config.nnet)
+         model.load_state_dict(torch.load(model_path, map_location='cpu'))
+         model.to(self.device)
+         model.eval()
+         return model
+
+     def load_clip_text_model(self) -> FrozenCLIPEmbedder:
+         clip_text_model = FrozenCLIPEmbedder(device=self.device)
+         clip_text_model.to(self.device)
+         clip_text_model.eval()
+         return clip_text_model
+
+     def load_autoencoder(self) -> FrozenAutoencoderKL:
+         autoencoder = get_autoencoder(**self.config.autoencoder)
+         autoencoder.to(self.device)
+         return autoencoder
+
+     def split(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         C, H, W = self.config.z_shape
+         z_dim = C * H * W
+         z, clip_img = x.split([z_dim, self.config.clip_img_dim], dim=1)
+         z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
+         clip_img = einops.rearrange(clip_img,
+                                     'B (L D) -> B L D',
+                                     L=1,
+                                     D=self.config.clip_img_dim)
+         return z, clip_img
+
+     @staticmethod
+     def combine(z, clip_img):
+         z = einops.rearrange(z, 'B C H W -> B (C H W)')
+         clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
+         return torch.concat([z, clip_img], dim=-1)
+
+     def t2i_nnet(
+         self, x, timesteps, text
+     ):  # text is the low dimension version of the text clip embedding
+         """
+         1. calculate the conditional model output
+         2. calculate unconditional model output
+             config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
+             config.sample.t2i_cfg_mode == 'true_uncond': using the unconditional model learned by our method
+         3. return linear combination of conditional output and unconditional output
+         """
+         z, clip_img = self.split(x)
+
+         t_text = torch.zeros(timesteps.size(0),
+                              dtype=torch.int,
+                              device=self.device)
+
+         z_out, clip_img_out, text_out = self.nnet(
+             z,
+             clip_img,
+             text=text,
+             t_img=timesteps,
+             t_text=t_text,
+             data_type=torch.zeros_like(
+                 t_text, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+         x_out = self.combine(z_out, clip_img_out)
+
+         if self.config.sample.scale == 0.:
+             return x_out
+
+         if self.config.sample.t2i_cfg_mode == 'empty_token':
+             _empty_context = einops.repeat(self.empty_context,
+                                            'L D -> B L D',
+                                            B=x.size(0))
+             if self.use_caption_decoder:
+                 _empty_context = self.caption_decoder.encode_prefix(
+                     _empty_context)
+             z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
+                 z,
+                 clip_img,
+                 text=_empty_context,
+                 t_img=timesteps,
+                 t_text=t_text,
+                 data_type=torch.zeros_like(
+                     t_text, device=self.device, dtype=torch.int) +
+                 self.config.data_type)
+             x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
+         elif self.config.sample.t2i_cfg_mode == 'true_uncond':
+             text_N = torch.randn_like(text)  # 3 other possible choices
+             z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
+                 z,
+                 clip_img,
+                 text=text_N,
+                 t_img=timesteps,
+                 t_text=torch.ones_like(timesteps) * self.N,
+                 data_type=torch.zeros_like(
+                     t_text, device=self.device, dtype=torch.int) +
+                 self.config.data_type)
+             x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
+         else:
+             raise NotImplementedError
+
+         return x_out + self.config.sample.scale * (x_out - x_out_uncond)
+
+     def i_nnet(self, x, timesteps):
+         z, clip_img = self.split(x)
+         text = torch.randn(x.size(0),
+                            77,
+                            self.config.text_dim,
+                            device=self.device)
+         t_text = torch.ones_like(timesteps) * self.N
+         z_out, clip_img_out, text_out = self.nnet(
+             z,
+             clip_img,
+             text=text,
+             t_img=timesteps,
+             t_text=t_text,
+             data_type=torch.zeros_like(
+                 t_text, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+         x_out = self.combine(z_out, clip_img_out)
+         return x_out
+
+     def t_nnet(self, x, timesteps):
+         z = torch.randn(x.size(0), *self.config.z_shape, device=self.device)
+         clip_img = torch.randn(x.size(0),
+                                1,
+                                self.config.clip_img_dim,
+                                device=self.device)
+         z_out, clip_img_out, text_out = self.nnet(
+             z,
+             clip_img,
+             text=x,
+             t_img=torch.ones_like(timesteps) * self.N,
+             t_text=timesteps,
+             data_type=torch.zeros_like(
+                 timesteps, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+         return text_out
+
+     def i2t_nnet(self, x, timesteps, z, clip_img):
+         """
+         1. calculate the conditional model output
+         2. calculate unconditional model output
+         3. return linear combination of conditional output and unconditional output
+         """
+         t_img = torch.zeros(timesteps.size(0),
+                             dtype=torch.int,
+                             device=self.device)
+
+         z_out, clip_img_out, text_out = self.nnet(
+             z,
+             clip_img,
+             text=x,
+             t_img=t_img,
+             t_text=timesteps,
+             data_type=torch.zeros_like(
+                 t_img, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+
+         if self.config.sample.scale == 0.:
+             return text_out
+
+         z_N = torch.randn_like(z)  # 3 other possible choices
+         clip_img_N = torch.randn_like(clip_img)
+         z_out_uncond, clip_img_out_uncond, text_out_uncond = self.nnet(
+             z_N,
+             clip_img_N,
+             text=x,
+             t_img=torch.ones_like(timesteps) * self.N,
+             t_text=timesteps,
+             data_type=torch.zeros_like(
+                 timesteps, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+
+         return text_out + self.config.sample.scale * (text_out -
+                                                       text_out_uncond)
+
+     def split_joint(self, x):
+         C, H, W = self.config.z_shape
+         z_dim = C * H * W
+         z, clip_img, text = x.split(
+             [z_dim, self.config.clip_img_dim, 77 * self.config.text_dim],
+             dim=1)
+         z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
+         clip_img = einops.rearrange(clip_img,
+                                     'B (L D) -> B L D',
+                                     L=1,
+                                     D=self.config.clip_img_dim)
+         text = einops.rearrange(text,
+                                 'B (L D) -> B L D',
+                                 L=77,
+                                 D=self.config.text_dim)
+         return z, clip_img, text
+
+     @staticmethod
+     def combine_joint(z: torch.Tensor, clip_img: torch.Tensor,
+                       text: torch.Tensor) -> torch.Tensor:
+         z = einops.rearrange(z, 'B C H W -> B (C H W)')
+         clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
+         text = einops.rearrange(text, 'B L D -> B (L D)')
+         return torch.concat([z, clip_img, text], dim=-1)
+
+     def joint_nnet(self, x, timesteps):
+         z, clip_img, text = self.split_joint(x)
+         z_out, clip_img_out, text_out = self.nnet(
+             z,
+             clip_img,
+             text=text,
+             t_img=timesteps,
+             t_text=timesteps,
+             data_type=torch.zeros_like(
+                 timesteps, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+         x_out = self.combine_joint(z_out, clip_img_out, text_out)
+
+         if self.config.sample.scale == 0.:
+             return x_out
+
+         z_noise = torch.randn(x.size(0),
+                               *self.config.z_shape,
+                               device=self.device)
+         clip_img_noise = torch.randn(x.size(0),
+                                      1,
+                                      self.config.clip_img_dim,
+                                      device=self.device)
+         text_noise = torch.randn(x.size(0),
+                                  77,
+                                  self.config.text_dim,
+                                  device=self.device)
+
+         _, _, text_out_uncond = self.nnet(
+             z_noise,
+             clip_img_noise,
+             text=text,
+             t_img=torch.ones_like(timesteps) * self.N,
+             t_text=timesteps,
+             data_type=torch.zeros_like(
+                 timesteps, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+         z_out_uncond, clip_img_out_uncond, _ = self.nnet(
+             z,
+             clip_img,
+             text=text_noise,
+             t_img=timesteps,
+             t_text=torch.ones_like(timesteps) * self.N,
+             data_type=torch.zeros_like(
+                 timesteps, device=self.device, dtype=torch.int) +
+             self.config.data_type)
+
+         x_out_uncond = self.combine_joint(z_out_uncond, clip_img_out_uncond,
+                                           text_out_uncond)
+
+         return x_out + self.config.sample.scale * (x_out - x_out_uncond)
+
+     @torch.cuda.amp.autocast()
+     def encode(self, _batch):
+         return self.autoencoder.encode(_batch)
+
+     @torch.cuda.amp.autocast()
+     def decode(self, _batch):
+         return self.autoencoder.decode(_batch)
+
+     def prepare_contexts(
+             self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         resolution = self.config.z_shape[-1] * 8
+
+         contexts = torch.randn(self.config.n_samples, 77,
+                                self.config.clip_text_dim).to(self.device)
+         img_contexts = torch.randn(self.config.n_samples,
+                                    2 * self.config.z_shape[0],
+                                    self.config.z_shape[1],
+                                    self.config.z_shape[2])
+         clip_imgs = torch.randn(self.config.n_samples, 1,
+                                 self.config.clip_img_dim)
+
+         if self.config.mode in ['t2i', 't2i2t']:
+             prompts = [self.config.prompt] * self.config.n_samples
+             contexts = self.clip_text_model.encode(prompts)
+
+         elif self.config.mode in ['i2t', 'i2t2i']:
+             img_contexts = []
+             clip_imgs = []
+
+             def get_img_feature(image):
+                 image = np.array(image).astype(np.uint8)
+                 image = utils.center_crop(resolution, resolution, image)
+                 clip_img_feature = self.clip_img_model.encode_image(
+                     self.clip_img_model_preprocess(
+                         PIL.Image.fromarray(image)).unsqueeze(0).to(
+                             self.device))
+
+                 image = (image / 127.5 - 1.0).astype(np.float32)
+                 image = einops.rearrange(image, 'h w c -> 1 c h w')
+                 image = torch.tensor(image, device=self.device)
+                 moments = self.autoencoder.encode_moments(image)
+
+                 return clip_img_feature, moments
+
+             image = PIL.Image.open(self.config.img).convert('RGB')
+             clip_img, img_context = get_img_feature(image)
+
+             img_contexts.append(img_context)
+             clip_imgs.append(clip_img)
+             img_contexts = img_contexts * self.config.n_samples
+             clip_imgs = clip_imgs * self.config.n_samples
+
+             img_contexts = torch.concat(img_contexts, dim=0)
+             clip_imgs = torch.stack(clip_imgs, dim=0)
+
+         return contexts, img_contexts, clip_imgs
+
+     @staticmethod
+     def unpreprocess(v: torch.Tensor) -> torch.Tensor:  # to B C H W and [0, 1]
+         v = 0.5 * (v + 1.)
+         v.clamp_(0., 1.)
+         return v
+
+     def get_sample_fn(self, _n_samples: int) -> Callable:
+         def sample_fn(mode: str, **kwargs):
+             _z_init = torch.randn(_n_samples,
+                                   *self.config.z_shape,
+                                   device=self.device)
+             _clip_img_init = torch.randn(_n_samples,
+                                          1,
+                                          self.config.clip_img_dim,
+                                          device=self.device)
+             _text_init = torch.randn(_n_samples,
+                                      77,
+                                      self.config.text_dim,
+                                      device=self.device)
+             if mode == 'joint':
+                 _x_init = self.combine_joint(_z_init, _clip_img_init,
+                                              _text_init)
+             elif mode in ['t2i', 'i']:
+                 _x_init = self.combine(_z_init, _clip_img_init)
+             elif mode in ['i2t', 't']:
+                 _x_init = _text_init
+             noise_schedule = NoiseScheduleVP(schedule='discrete',
+                                              betas=torch.tensor(
+                                                  self.betas,
+                                                  device=self.device).float())
+
+             def model_fn(x, t_continuous):
+                 t = t_continuous * self.N
+                 if mode == 'joint':
+                     return self.joint_nnet(x, t)
+                 elif mode == 't2i':
+                     return self.t2i_nnet(x, t, **kwargs)
+                 elif mode == 'i2t':
+                     return self.i2t_nnet(x, t, **kwargs)
+                 elif mode == 'i':
+                     return self.i_nnet(x, t)
+                 elif mode == 't':
+                     return self.t_nnet(x, t)
+
+             dpm_solver = DPM_Solver(model_fn,
+                                     noise_schedule,
+                                     predict_x0=True,
+                                     thresholding=False)
+             with torch.inference_mode(), torch.autocast(
+                     device_type=self.device.type):
+                 x = dpm_solver.sample(_x_init,
+                                       steps=self.config.sample.sample_steps,
+                                       eps=1. / self.N,
+                                       T=1.)
+
+             if mode == 'joint':
+                 _z, _clip_img, _text = self.split_joint(x)
+                 return _z, _clip_img, _text
+             elif mode in ['t2i', 'i']:
+                 _z, _clip_img = self.split(x)
+                 return _z, _clip_img
+             elif mode in ['i2t', 't']:
+                 return x
+
+         return sample_fn
+
+     @staticmethod
+     def to_pil(tensor: torch.Tensor) -> PIL.Image.Image:
+         return PIL.Image.fromarray(
+             tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
+                 'cpu', torch.uint8).numpy())
+
+     def run(self, mode: str, prompt: str, image_path: str, seed: int,
+             num_steps: int,
+             guidance_scale: float) -> tuple[PIL.Image.Image | None, str]:
+         self.config.mode = mode
+         self.config.prompt = prompt
+         self.config.img = image_path
+         self.config.seed = seed
+         self.config.sample.sample_steps = num_steps
+         self.config.sample.scale = guidance_scale
+         self.config.n_samples = 1
+
+         #set_seed(self.config.seed)
+         if seed == -1:
+             seed = random.randint(0, 1000000)
+         torch.manual_seed(seed)
+
+         contexts, img_contexts, clip_imgs = self.prepare_contexts()
+         if self.use_caption_decoder:
+             contexts_low_dim = self.caption_decoder.encode_prefix(contexts)
+         else:
+             contexts_low_dim = contexts
+         z_img = self.autoencoder.sample(img_contexts)
+
+         if self.config.mode in ['t2i', 't2i2t']:
+             _n_samples = contexts_low_dim.size(0)
+         elif self.config.mode in ['i2t', 'i2t2i']:
+             _n_samples = img_contexts.size(0)
+         else:
+             _n_samples = self.config.n_samples
+         sample_fn = self.get_sample_fn(_n_samples)
+
+         if self.config.mode == 'joint':
+             _z, _clip_img, _text = sample_fn(self.config.mode)
+             samples = self.unpreprocess(self.decode(_z))
+             samples = [self.to_pil(tensor) for tensor in samples]
+             prompts = self.caption_decoder.generate_captions(_text)
+             return samples[0], prompts[0]
+
+         elif self.config.mode in ['t2i', 'i', 'i2t2i']:
+             if self.config.mode == 't2i':
+                 _z, _clip_img = sample_fn(
+                     self.config.mode,
+                     text=contexts_low_dim)  # conditioned on the text embedding
+             elif self.config.mode == 'i':
+                 _z, _clip_img = sample_fn(self.config.mode)
+             elif self.config.mode == 'i2t2i':
+                 _text = sample_fn(
+                     'i2t', z=z_img,
+                     clip_img=clip_imgs)  # conditioned on the image embedding
+                 _z, _clip_img = sample_fn('t2i', text=_text)
+             samples = self.unpreprocess(self.decode(_z))
+             samples = [self.to_pil(tensor) for tensor in samples]
+             return samples[0], ''
+
+         elif self.config.mode in ['i2t', 't', 't2i2t']:
+             if self.config.mode == 'i2t':
+                 _text = sample_fn(
+                     self.config.mode, z=z_img,
+                     clip_img=clip_imgs)  # conditioned on the image embedding
+             elif self.config.mode == 't':
+                 _text = sample_fn(self.config.mode)
+             elif self.config.mode == 't2i2t':
+                 _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
+                 _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
+             prompts = self.caption_decoder.generate_captions(_text)
+             return None, prompts[0]
+         else:
+             raise ValueError
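The guided predictors above (t2i_nnet, i2t_nnet, joint_nnet) all end with the same classifier-free-guidance step: the conditional prediction is pushed away from an unconditional one by config.sample.scale. A self-contained toy sketch of that combination rule (illustrative values only, not part of the commit):

    # out = cond + scale * (cond - uncond); scale == 0 returns the conditional
    # prediction unchanged, larger scales amplify the conditional direction.
    import torch


    def apply_guidance(cond: torch.Tensor, uncond: torch.Tensor,
                       scale: float) -> torch.Tensor:
        return cond + scale * (cond - uncond)


    print(apply_guidance(torch.tensor([1., 2.]), torch.tensor([0., 1.]), 0.))  # tensor([1., 2.])
    print(apply_guidance(torch.tensor([1., 2.]), torch.tensor([0., 1.]), 7.))  # tensor([8., 9.])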
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ absl-py==1.4.0
+ accelerate==0.12.0
+ einops==0.6.0
+ ftfy==6.1.1
+ git+https://github.com/openai/CLIP.git@a9b1bf5
+ gradio==3.21.0
+ huggingface-hub==0.13.2
+ ml-collections==0.1.1
+ torch==1.13.1
+ torchvision==0.14.1
+ transformers==4.23.1
+ triton==2.0.0
+ xformers==0.0.16
style.css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+   text-align: center;
+ }
unidiffuser ADDED
@@ -0,0 +1 @@
+ Subproject commit 390368777ce0a6102f50361ab6dae8e0991447a8