MH0386 commited on
Commit
3e165b2
·
verified ·
1 Parent(s): 054c6c0

Upload folder using huggingface_hub

Browse files
Files changed (46) hide show
  1. .dockerignore +10 -0
  2. .gitignore +17 -0
  3. .pre-commit-config.yaml +24 -0
  4. .python-version +1 -0
  5. Dockerfile +46 -0
  6. README.md +6 -8
  7. compose.yaml +26 -0
  8. pyproject.toml +56 -0
  9. qodana.yaml +44 -0
  10. src/visualizr/LIA_Model.py +52 -0
  11. src/visualizr/__init__.py +59 -0
  12. src/visualizr/__main__.py +22 -0
  13. src/visualizr/choices.py +181 -0
  14. src/visualizr/config.py +394 -0
  15. src/visualizr/config_base.py +83 -0
  16. src/visualizr/dataset.py +246 -0
  17. src/visualizr/dataset_util.py +14 -0
  18. src/visualizr/diffusion/__init__.py +9 -0
  19. src/visualizr/diffusion/base.py +1136 -0
  20. src/visualizr/diffusion/diffusion.py +183 -0
  21. src/visualizr/diffusion/resample.py +63 -0
  22. src/visualizr/dist_utils.py +43 -0
  23. src/visualizr/experiment.py +386 -0
  24. src/visualizr/face_sr/face_enhancer.py +134 -0
  25. src/visualizr/face_sr/videoio.py +14 -0
  26. src/visualizr/gui.py +97 -0
  27. src/visualizr/model/__init__.py +7 -0
  28. src/visualizr/model/base.py +28 -0
  29. src/visualizr/model/blocks.py +572 -0
  30. src/visualizr/model/diffusion.py +323 -0
  31. src/visualizr/model/latentnet.py +189 -0
  32. src/visualizr/model/nn.py +129 -0
  33. src/visualizr/model/seq2seq.py +223 -0
  34. src/visualizr/model/unet.py +561 -0
  35. src/visualizr/model/unet_autoenc.py +291 -0
  36. src/visualizr/networks/__init__.py +0 -0
  37. src/visualizr/networks/discriminator.py +300 -0
  38. src/visualizr/networks/encoder.py +432 -0
  39. src/visualizr/networks/generator.py +37 -0
  40. src/visualizr/networks/styledecoder.py +618 -0
  41. src/visualizr/networks/utils.py +49 -0
  42. src/visualizr/renderer.py +35 -0
  43. src/visualizr/settings.py +29 -0
  44. src/visualizr/templates.py +302 -0
  45. src/visualizr/utils.py +432 -0
  46. uv.lock +0 -0
.dockerignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ **/__pycache__/
2
+ .deepsource.toml
3
+ .env
4
+ .github/
5
+ .idea/
6
+ .mypy_cache/
7
+ .ruff_cache/
8
+ .venv/
9
+ .vscode/
10
+ renovate.json
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/__pycache__/
2
+ **.dccache
3
+ **.env
4
+ .mypy_cache/
5
+ .ruff_cache/
6
+ .venv/
7
+ logs/
8
+ results/
9
+ ckpts/
10
+ Anitalker/
11
+ outputs/
12
+
13
+ .github/
14
+ .trunk/
15
+ .idea/
16
+ renovate.json
17
+ .deepsource.toml
.pre-commit-config.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fail_fast: false
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v5.0.0
5
+ hooks:
6
+ - id: check-yaml
7
+ - id: check-toml
8
+ - id: check-json
9
+ - id: trailing-whitespace
10
+ - id: check-merge-conflict
11
+ - repo: https://github.com/astral-sh/uv-pre-commit
12
+ rev: 0.7.13
13
+ hooks:
14
+ - id: uv-lock
15
+ - repo: https://github.com/astral-sh/ruff-pre-commit
16
+ rev: v0.12.0
17
+ hooks:
18
+ - id: ruff-check
19
+ args: [ --fix ]
20
+ - id: ruff-format
21
+ - repo: https://github.com/abravalheri/validate-pyproject
22
+ rev: v0.24.1
23
+ hooks:
24
+ - id: validate-pyproject
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10
Dockerfile ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10 AS builder
2
+
3
+ SHELL ["/bin/bash", "-c"]
4
+
5
+ ENV UV_LINK_MODE=copy \
6
+ UV_COMPILE_BYTECODE=1 \
7
+ UV_PYTHON_DOWNLOADS=0
8
+
9
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
10
+
11
+ WORKDIR /app
12
+
13
+ RUN --mount=type=cache,target=/root/.cache/uv \
14
+ --mount=type=bind,source=uv.lock,target=uv.lock \
15
+ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
16
+ --mount=type=bind,source=README.md,target=README.md \
17
+ uv sync --no-install-project --no-dev --locked --no-editable
18
+
19
+ COPY . /app
20
+
21
+ RUN --mount=type=cache,target=/root/.cache/uv \
22
+ uv sync --no-dev --locked --no-editable
23
+
24
+ FROM python:3.10-slim AS production
25
+
26
+ SHELL ["/bin/bash", "-c"]
27
+
28
+ ENV GRADIO_SERVER_PORT=7860 \
29
+ GRADIO_SERVER_NAME=0.0.0.0
30
+ # skipcq: DOK-DL3008
31
+ RUN groupadd app && \
32
+ useradd -m -g app -s /bin/bash app && \
33
+ apt-get update -qq && \
34
+ apt-get install -qq -y --no-install-recommends espeak-ng ffmpeg && \
35
+ apt-get clean -qq && \
36
+ rm -rf /var/lib/apt/lists/*
37
+
38
+ WORKDIR /home/app
39
+
40
+ COPY --from=builder --chown=app:app /app/.venv /app/.venv
41
+
42
+ USER app
43
+
44
+ EXPOSE ${GRADIO_SERVER_PORT}
45
+
46
+ CMD ["/app/.venv/bin/vocalizr"]
README.md CHANGED
@@ -1,12 +1,10 @@
1
  ---
2
  title: Visualizr
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.34.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Visualizr
3
+ emoji: 👁️
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: docker
7
+ app_port: 7860
 
 
8
  ---
9
 
10
+ ## **Visualizr**: Video Generator part of the Chatacter Backend
compose.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Visualizr
2
+ services:
3
+ visualizr:
4
+ image: ghcr.io/alphaspheredotai/visualizr:latest
5
+ ports:
6
+ - "7860:7860"
7
+ volumes:
8
+ - visualizr_venv:/venv
9
+ restart: on-failure:3
10
+ healthcheck:
11
+ test: ["CMD", "curl", "-f", "http://localhost:7860"]
12
+ interval: 1m30s
13
+ timeout: 10s
14
+ retries: 5
15
+ start_period: 40s
16
+ start_interval: 5s
17
+ deploy:
18
+ resources:
19
+ reservations:
20
+ devices:
21
+ - driver: nvidia
22
+ count: all
23
+ capabilities: [gpu]
24
+ volumes:
25
+ visualizr_venv:
26
+ name: Python Environment for Visualizr
pyproject.toml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "visualizr"
3
+ description = "Video Generator part of the Chatacter Backend"
4
+ version = "0.1.0"
5
+ readme = "README.md"
6
+ requires-python = "~=3.10"
7
+ dependencies = [
8
+ "espnet>=202412",
9
+ "gfpgan>=1.3.8",
10
+ "gradio[mcp]>=5.34.0",
11
+ "librosa>=0.9.2",
12
+ "loguru>=0.7.3",
13
+ "moviepy>=1.0.3",
14
+ "python-speech-features>=0.6",
15
+ "pytorch-lightning>=2.5.1.post0",
16
+ "realesrgan>=0.3.0",
17
+ "spaces>=0.37.0",
18
+ "torch>=2.6.0",
19
+ "torchaudio>=2.6.0",
20
+ "torchmetrics>=1.7.1",
21
+ "torchvision>=0.21.0",
22
+ "tqdm>=4.67.1",
23
+ "transformers>=4.52.2",
24
+ ]
25
+
26
+ [project.scripts]
27
+ visualizr = "visualizr.__main__:main"
28
+
29
+ [build-system]
30
+ build-backend = "uv_build"
31
+ requires = ["uv_build"]
32
+
33
+ [dependency-groups]
34
+ dev = [
35
+ "huggingface-hub[cli,hf-transfer]>=0.33.0",
36
+ "pyrefly>=0.19.2",
37
+ "ruff>=0.11.10",
38
+ "ty>=0.0.1a6",
39
+ "watchfiles>=1.0.5",
40
+ ]
41
+
42
+ [tool.uv.sources]
43
+ torch = [
44
+ { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
45
+ ]
46
+ torchaudio = [
47
+ { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
48
+ ]
49
+ torchvision = [
50
+ { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
51
+ ]
52
+
53
+ [[tool.uv.index]]
54
+ explicit = true
55
+ name = "pytorch-cu124"
56
+ url = "https://download.pytorch.org/whl/cu124"
qodana.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "1.0"
2
+ linter: jetbrains/qodana-python:2025.1
3
+ bootstrap: |
4
+ export UV_CACHE_DIR=/data/cache
5
+ export CONDA_PREFIX=/opt/miniconda3
6
+ mkdir .log
7
+ conda config --add channels defaults > .log/output.log
8
+ conda install -y python=3.10 > .log/output.log
9
+ pip install uv
10
+ uv pip sync pyproject.toml --python 3.10
11
+ rm -rd .log
12
+ exclude:
13
+ - name: All
14
+ paths:
15
+ - uv.lock
16
+ profile:
17
+ name: qodana.recommended
18
+ include:
19
+ - name: CheckDependencyLicenses
20
+ - name: PyArgumentListInspection
21
+ - name: PyTypeCheckerInspection
22
+ - name: PyDataclassInspection
23
+ - name: PyUnresolvedReferencesInspection
24
+ - name: CyclomaticComplexityInspection
25
+ - name: UnsatisfiedRequirementInspection
26
+ - name: IgnoreFileDuplicateEntry
27
+ - name: YAMLSchemaDeprecation
28
+ - name: YAMLDuplicatedKeys
29
+ - name: YAMLRecursiveAlias
30
+ - name: YAMLIncompatibleTypes
31
+ - name: YAMLUnresolvedAlias
32
+ - name: YAMLUnusedAnchor
33
+ - name: YAMLSchemaValidation
34
+ - name: CheckModuleLicenses
35
+ - name: CheckThirdPartySoftwareList
36
+ - name: Annotator
37
+ - name: EmptyDirectory
38
+ - name: InconsistentLineSeparators
39
+ - name: IncorrectFormatting
40
+ - name: LongLine
41
+ - name: ProblematicWhitespace
42
+ - name: ReassignedToPlainText
43
+ - name: RedundantSuppression
44
+ - name: TodoComment
src/visualizr/LIA_Model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import load, nn
2
+
3
+ from visualizr import logger
4
+ from visualizr.networks.encoder import Encoder
5
+ from visualizr.networks.styledecoder import Synthesis
6
+
7
+
8
+ class LIA_Model(nn.Module):
9
+ def __init__(
10
+ self,
11
+ size=256,
12
+ style_dim=512,
13
+ motion_dim=20,
14
+ channel_multiplier=1,
15
+ blur_kernel=[1, 3, 3, 1],
16
+ fusion_type="",
17
+ ):
18
+ super().__init__()
19
+ self.enc = Encoder(size, style_dim, motion_dim, fusion_type)
20
+ self.dec = Synthesis(
21
+ size, style_dim, motion_dim, blur_kernel, channel_multiplier
22
+ )
23
+
24
+ def get_start_direction_code(self, x_start, x_target, x_face, x_aug):
25
+ enc_dic = self.enc(x_start, x_target, x_face, x_aug)
26
+
27
+ wa, alpha, feats = enc_dic["h_source"], enc_dic["h_motion"], enc_dic["feats"]
28
+
29
+ return wa, alpha, feats
30
+
31
+ def render(self, start, direction, feats):
32
+ return self.dec(start, direction, feats)
33
+
34
+ def load_lightning_model(self, lia_pretrained_model_path):
35
+ selfState = self.state_dict()
36
+
37
+ state = load(lia_pretrained_model_path, map_location="cpu")
38
+ for name, param in state.items():
39
+ origName = name
40
+ if name not in selfState:
41
+ name = name.replace("lia.", "")
42
+ if name not in selfState:
43
+ logger.exception("%s is not in the model." % origName)
44
+ # You can ignore those errors as some parameters are only used for training
45
+ continue
46
+ if selfState[name].size() != state[origName].size():
47
+ logger.exception(
48
+ "Wrong parameter length: %s, model: %s, loaded: %s"
49
+ % (origName, selfState[name].size(), state[origName].size())
50
+ )
51
+ continue
52
+ selfState[name].copy_(param)
src/visualizr/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from os import getenv
3
+ from pathlib import Path
4
+
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import snapshot_download
7
+ from loguru import logger
8
+ from torch import cuda
9
+
10
+ load_dotenv()
11
+
12
+ DEBUG: bool = getenv(key="DEBUG", default="True").lower() == "true"
13
+ SERVER_NAME: str = getenv(key="GRADIO_SERVER_NAME", default="localhost")
14
+ SERVER_PORT: int = int(getenv(key="GRADIO_SERVER_PORT", default="8080"))
15
+ CURRENT_DATE: str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
16
+
17
+ BASE_DIR: Path = Path.cwd()
18
+ RESULTS_DIR: Path = BASE_DIR / "results"
19
+ LOG_DIR: Path = BASE_DIR / "logs"
20
+ CHECKPOINT_DIR: Path = BASE_DIR / "ckpts"
21
+ AUDIO_FILE_PATH: Path = RESULTS_DIR / f"{CURRENT_DATE}.wav"
22
+ LOG_FILE_PATH: Path = LOG_DIR / f"{CURRENT_DATE}.log"
23
+ CUDA_AVAILABLE: bool = cuda.is_available()
24
+
25
+ FRAMES_RESULT_SAVED_PATH: Path = RESULTS_DIR / "frames"
26
+ STAGE_1_CHECKPOINT_PATH = CHECKPOINT_DIR / "stage1.ckpt"
27
+ VIDEO_PATH = RESULTS_DIR / f"{CURRENT_DATE}.mp4"
28
+
29
+ RESULTS_DIR.mkdir(exist_ok=True)
30
+ LOG_DIR.mkdir(exist_ok=True)
31
+ CHECKPOINT_DIR.mkdir(exist_ok=True)
32
+ FRAMES_RESULT_SAVED_PATH.mkdir(exist_ok=True)
33
+
34
+ MOTION_DIM: int = 20
35
+ TMP_MP4: str = ".tmp.mp4"
36
+
37
+ logger.add(
38
+ sink=LOG_FILE_PATH,
39
+ format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
40
+ colorize=True,
41
+ )
42
+ logger.info(f"CUDA Available: {CUDA_AVAILABLE}")
43
+ logger.info(f"Current date: {CURRENT_DATE}")
44
+ logger.info(f"Base directory: {BASE_DIR}")
45
+ logger.info(f"Results directory: {RESULTS_DIR}")
46
+ logger.info(f"Log directory: {LOG_DIR}")
47
+ logger.info(f"Checkpoint directory: {CHECKPOINT_DIR}")
48
+
49
+ model_mapping: dict[str, str] = {
50
+ "mfcc_pose_only": f"{CHECKPOINT_DIR}/stage2_pose_only_mfcc.ckpt",
51
+ "mfcc_full_control": f"{CHECKPOINT_DIR}/stage2_more_controllable_mfcc.ckpt",
52
+ "hubert_audio_only": f"{CHECKPOINT_DIR}/stage2_audio_only_hubert.ckpt",
53
+ "hubert_pose_only": f"{CHECKPOINT_DIR}/stage2_pose_only_hubert.ckpt",
54
+ "hubert_full_control": f"{CHECKPOINT_DIR}/stage2_full_control_hubert.ckpt",
55
+ }
56
+
57
+ snapshot_download(
58
+ repo_id="taocode/anitalker_ckpts", local_dir=CHECKPOINT_DIR, repo_type="model"
59
+ )
src/visualizr/__main__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio import Blocks
2
+
3
+ from visualizr import DEBUG, SERVER_NAME, SERVER_PORT
4
+ from visualizr.gui import app_block
5
+
6
+
7
+ def main() -> None:
8
+ """Launch the Gradio voice generation web application."""
9
+ app: Blocks = app_block()
10
+ app.queue(api_open=True).launch(
11
+ server_name=SERVER_NAME,
12
+ server_port=SERVER_PORT,
13
+ debug=DEBUG,
14
+ mcp_server=True,
15
+ show_api=True,
16
+ enable_monitoring=True,
17
+ show_error=True,
18
+ )
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()
src/visualizr/choices.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from torch import nn
4
+
5
+
6
+ class TrainMode(Enum):
7
+ # manipulate mode = training the classifier
8
+ manipulate = "manipulate"
9
+ # default trainin mode!
10
+ diffusion = "diffusion"
11
+ # default latent training mode!
12
+ # fitting the a DDPM to a given latent
13
+ latent_diffusion = "latentdiffusion"
14
+
15
+ def is_manipulate(self):
16
+ return self in [
17
+ TrainMode.manipulate,
18
+ ]
19
+
20
+ def is_diffusion(self):
21
+ return self in [
22
+ TrainMode.diffusion,
23
+ TrainMode.latent_diffusion,
24
+ ]
25
+
26
+ def is_autoenc(self):
27
+ # the network possibly does autoencoding
28
+ return self in [
29
+ TrainMode.diffusion,
30
+ ]
31
+
32
+ def is_latent_diffusion(self):
33
+ return self in [
34
+ TrainMode.latent_diffusion,
35
+ ]
36
+
37
+ def use_latent_net(self):
38
+ return self.is_latent_diffusion()
39
+
40
+ def require_dataset_infer(self):
41
+ """
42
+ whether training in this mode requires the latent variables to be available?
43
+ """
44
+ # this will precalculate all the latents before hand
45
+ # and the dataset will be all the predicted latents
46
+ return self in [
47
+ TrainMode.latent_diffusion,
48
+ TrainMode.manipulate,
49
+ ]
50
+
51
+
52
+ class ManipulateMode(Enum):
53
+ """
54
+ how to train the classifier to manipulate
55
+ """
56
+
57
+ # train on whole celeba attr dataset
58
+ celebahq_all = "celebahq_all"
59
+ # celeba with D2C's crop
60
+ d2c_fewshot = "d2cfewshot"
61
+ d2c_fewshot_allneg = "d2cfewshotallneg"
62
+
63
+ def is_celeba_attr(self):
64
+ return self in [
65
+ ManipulateMode.d2c_fewshot,
66
+ ManipulateMode.d2c_fewshot_allneg,
67
+ ManipulateMode.celebahq_all,
68
+ ]
69
+
70
+ def is_single_class(self):
71
+ return self in [
72
+ ManipulateMode.d2c_fewshot,
73
+ ManipulateMode.d2c_fewshot_allneg,
74
+ ]
75
+
76
+ def is_fewshot(self):
77
+ return self in [
78
+ ManipulateMode.d2c_fewshot,
79
+ ManipulateMode.d2c_fewshot_allneg,
80
+ ]
81
+
82
+ def is_fewshot_allneg(self):
83
+ return self in [
84
+ ManipulateMode.d2c_fewshot_allneg,
85
+ ]
86
+
87
+
88
+ class ModelType(Enum):
89
+ """
90
+ Kinds of the backbone models
91
+ """
92
+
93
+ # unconditional ddpm
94
+ ddpm = "ddpm"
95
+ # autoencoding ddpm cannot do unconditional generation
96
+ autoencoder = "autoencoder"
97
+
98
+ def has_autoenc(self):
99
+ return self in [
100
+ ModelType.autoencoder,
101
+ ]
102
+
103
+ def can_sample(self):
104
+ return self in [ModelType.ddpm]
105
+
106
+
107
+ class ModelName(Enum):
108
+ """
109
+ List of all supported model classes
110
+ """
111
+
112
+ beatgans_ddpm = "beatgans_ddpm"
113
+ beatgans_autoenc = "beatgans_autoenc"
114
+
115
+
116
+ class ModelMeanType(Enum):
117
+ """
118
+ Which type of output the model predicts.
119
+ """
120
+
121
+ eps = "eps" # the model predicts epsilon
122
+
123
+
124
+ class ModelVarType(Enum):
125
+ """
126
+ What is used as the model's output variance.
127
+
128
+ The LEARNED_RANGE option has been added to allow the model to predict
129
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
130
+ """
131
+
132
+ # posterior beta_t
133
+ fixed_small = "fixed_small"
134
+ # beta_t
135
+ fixed_large = "fixed_large"
136
+
137
+
138
+ class LossType(Enum):
139
+ mse = "mse" # use raw MSE loss (and KL when learning variances)
140
+ l1 = "l1"
141
+
142
+
143
+ class GenerativeType(Enum):
144
+ """
145
+ How's a sample generated
146
+ """
147
+
148
+ ddpm = "ddpm"
149
+ ddim = "ddim"
150
+
151
+
152
+ class OptimizerType(Enum):
153
+ adam = "adam"
154
+ adamw = "adamw"
155
+
156
+
157
+ class Activation(Enum):
158
+ none = "none"
159
+ relu = "relu"
160
+ lrelu = "lrelu"
161
+ silu = "silu"
162
+ tanh = "tanh"
163
+
164
+ def get_act(self):
165
+ if self == Activation.none:
166
+ return nn.Identity()
167
+ elif self == Activation.relu:
168
+ return nn.ReLU()
169
+ elif self == Activation.lrelu:
170
+ return nn.LeakyReLU(negative_slope=0.2)
171
+ elif self == Activation.silu:
172
+ return nn.SiLU()
173
+ elif self == Activation.tanh:
174
+ return nn.Tanh()
175
+ else:
176
+ raise NotImplementedError()
177
+
178
+
179
+ class ManipulateLossType(Enum):
180
+ bce = "bce"
181
+ mse = "mse"
src/visualizr/config.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from os import path
3
+ from typing import Tuple
4
+
5
+ from torch import distributed
6
+ from torch.multiprocessing import get_context
7
+ from torch.utils.data import DataLoader
8
+ from torch.utils.data.distributed import DistributedSampler
9
+
10
+ from visualizr.choices import (
11
+ Activation,
12
+ GenerativeType,
13
+ LossType,
14
+ ManipulateLossType,
15
+ ManipulateMode,
16
+ ModelMeanType,
17
+ ModelName,
18
+ ModelType,
19
+ ModelVarType,
20
+ OptimizerType,
21
+ TrainMode,
22
+ )
23
+ from visualizr.config_base import BaseConfig
24
+ from visualizr.dataset import LatentDataLoader
25
+ from visualizr.diffusion.base import get_named_beta_schedule
26
+ from visualizr.diffusion.diffusion import SpacedDiffusionBeatGansConfig, space_timesteps
27
+ from visualizr.diffusion.resample import UniformSampler
28
+ from visualizr.model import BeatGANsAutoencConfig, BeatGANsUNetConfig, ModelConfig
29
+ from visualizr.model.blocks import ScaleAt
30
+ from visualizr.model.latentnet import LatentNetType, MLPSkipNetConfig
31
+
32
+
33
+ @dataclass
34
+ class PretrainConfig(BaseConfig):
35
+ name: str
36
+ path: str
37
+
38
+
39
+ @dataclass
40
+ class TrainConfig(BaseConfig):
41
+ # random seed
42
+ seed: int = 0
43
+ train_mode: TrainMode = TrainMode.diffusion
44
+ train_cond0_prob: float = 0
45
+ train_pred_xstart_detach: bool = True
46
+ train_interpolate_prob: float = 0
47
+ train_interpolate_img: bool = False
48
+ manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
49
+ manipulate_cls: str = None
50
+ manipulate_shots: int = None
51
+ manipulate_loss: ManipulateLossType = ManipulateLossType.bce
52
+ manipulate_znormalize: bool = False
53
+ manipulate_seed: int = 0
54
+ accum_batches: int = 1
55
+ autoenc_mid_attn: bool = True
56
+ batch_size: int = 16
57
+ batch_size_eval: int = None
58
+ beatgans_gen_type: GenerativeType = GenerativeType.ddim
59
+ beatgans_loss_type: LossType = LossType.mse
60
+ beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
61
+ beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
62
+ beatgans_rescale_timesteps: bool = False
63
+ latent_infer_path: str = None
64
+ latent_znormalize: bool = False
65
+ latent_gen_type: GenerativeType = GenerativeType.ddim
66
+ latent_loss_type: LossType = LossType.mse
67
+ latent_model_mean_type: ModelMeanType = ModelMeanType.eps
68
+ latent_model_var_type: ModelVarType = ModelVarType.fixed_large
69
+ latent_rescale_timesteps: bool = False
70
+ latent_T_eval: int = 1_000
71
+ latent_clip_sample: bool = False
72
+ latent_beta_scheduler: str = "linear"
73
+ beta_scheduler: str = "linear"
74
+ data_name: str = ""
75
+ data_val_name: str = None
76
+ diffusion_type: str = None
77
+ dropout: float = 0.1
78
+ ema_decay: float = 0.9999
79
+ eval_num_images: int = 5_000
80
+ eval_every_samples: int = 200_000
81
+ eval_ema_every_samples: int = 200_000
82
+ fid_use_torch: bool = True
83
+ fp16: bool = False
84
+ grad_clip: float = 1
85
+ img_size: int = 64
86
+ lr: float = 0.0001
87
+ optimizer: OptimizerType = OptimizerType.adam
88
+ weight_decay: float = 0
89
+ model_conf: ModelConfig = None
90
+ model_name: ModelName = None
91
+ model_type: ModelType = None
92
+ net_attn: Tuple[int] = None
93
+ net_beatgans_attn_head: int = 1
94
+ # not necessarily the same as the the number of style channels
95
+ net_beatgans_embed_channels: int = 512
96
+ net_resblock_updown: bool = True
97
+ net_enc_use_time: bool = False
98
+ net_enc_pool: str = "adaptivenonzero"
99
+ net_beatgans_gradient_checkpoint: bool = False
100
+ net_beatgans_resnet_two_cond: bool = False
101
+ net_beatgans_resnet_use_zero_module: bool = True
102
+ net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
103
+ net_beatgans_resnet_cond_channels: int = None
104
+ net_ch_mult: Tuple[int] = None
105
+ net_ch: int = 64
106
+ net_enc_attn: Tuple[int] = None
107
+ net_enc_k: int = None
108
+ # number of resblocks for the encoder (half-unet)
109
+ net_enc_num_res_blocks: int = 2
110
+ net_enc_channel_mult: Tuple[int] = None
111
+ net_enc_grad_checkpoint: bool = False
112
+ net_autoenc_stochastic: bool = False
113
+ net_latent_activation: Activation = Activation.silu
114
+ net_latent_channel_mult: Tuple[int] = (1, 2, 4)
115
+ net_latent_condition_bias: float = 0
116
+ net_latent_dropout: float = 0
117
+ net_latent_layers: int = None
118
+ net_latent_net_last_act: Activation = Activation.none
119
+ net_latent_net_type: LatentNetType = LatentNetType.none
120
+ net_latent_num_hid_channels: int = 1024
121
+ net_latent_num_time_layers: int = 2
122
+ net_latent_skip_layers: Tuple[int] = None
123
+ net_latent_time_emb_channels: int = 64
124
+ net_latent_use_norm: bool = False
125
+ net_latent_time_last_act: bool = False
126
+ net_num_res_blocks: int = 2
127
+ # number of resblocks for the UNET
128
+ net_num_input_res_blocks: int = None
129
+ net_enc_num_cls: int = None
130
+ num_workers: int = 4
131
+ parallel: bool = False
132
+ postfix: str = ""
133
+ sample_size: int = 64
134
+ sample_every_samples: int = 20_000
135
+ save_every_samples: int = 100_000
136
+ style_ch: int = 512
137
+ T_eval: int = 1_000
138
+ T_sampler: str = "uniform"
139
+ T: int = 1_000
140
+ total_samples: int = 10_000_000
141
+ warmup: int = 0
142
+ pretrain: PretrainConfig = None
143
+ continue_from: PretrainConfig = None
144
+ eval_programs: Tuple[str] = None
145
+ # if present load the checkpoint from this path instead
146
+ eval_path: str = None
147
+ base_dir: str = "checkpoints"
148
+ use_cache_dataset: bool = False
149
+ data_cache_dir: str = path.expanduser("~/cache")
150
+ work_cache_dir: str = path.expanduser("~/mycache")
151
+ # to be overridden
152
+ name: str = ""
153
+
154
+ def __post_init__(self):
155
+ self.batch_size_eval = self.batch_size_eval or self.batch_size
156
+ self.data_val_name = self.data_val_name or self.data_name
157
+
158
+ def scale_up_gpus(self, num_gpus, num_nodes=1):
159
+ self.eval_ema_every_samples *= num_gpus * num_nodes
160
+ self.eval_every_samples *= num_gpus * num_nodes
161
+ self.sample_every_samples *= num_gpus * num_nodes
162
+ self.batch_size *= num_gpus * num_nodes
163
+ self.batch_size_eval *= num_gpus * num_nodes
164
+ return self
165
+
166
+ @property
167
+ def batch_size_effective(self):
168
+ return self.batch_size * self.accum_batches
169
+
170
+ @property
171
+ def fid_cache(self):
172
+ # we try to use the local dirs to reduce the load over network drives
173
+ # hopefully, this would reduce the disconnection problems with sshfs
174
+ return f"{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}"
175
+
176
+ @property
177
+ def logdir(self):
178
+ return f"{self.base_dir}/{self.name}"
179
+
180
+ @property
181
+ def generate_dir(self):
182
+ # we try to use the local dirs to reduce the load over network drives
183
+ # hopefully, this would reduce the disconnection problems with sshfs
184
+ return f"{self.work_cache_dir}/gen_images/{self.name}"
185
+
186
+ def _make_diffusion_conf(self, T=None):
187
+ if self.diffusion_type == "beatgans":
188
+ # can use T < self.T for evaluation
189
+ # follows the guided-diffusion repo conventions
190
+ # t's are evenly spaced
191
+ if self.beatgans_gen_type == GenerativeType.ddpm:
192
+ section_counts = [T]
193
+ elif self.beatgans_gen_type == GenerativeType.ddim:
194
+ section_counts = f"ddim{T}"
195
+ else:
196
+ raise NotImplementedError()
197
+
198
+ return SpacedDiffusionBeatGansConfig(
199
+ gen_type=self.beatgans_gen_type,
200
+ model_type=self.model_type,
201
+ betas=get_named_beta_schedule(self.beta_scheduler, self.T),
202
+ model_mean_type=self.beatgans_model_mean_type,
203
+ model_var_type=self.beatgans_model_var_type,
204
+ loss_type=self.beatgans_loss_type,
205
+ rescale_timesteps=self.beatgans_rescale_timesteps,
206
+ use_timesteps=space_timesteps(
207
+ num_timesteps=self.T, section_counts=section_counts
208
+ ),
209
+ fp16=self.fp16,
210
+ )
211
+ else:
212
+ raise NotImplementedError()
213
+
214
+ def _make_latent_diffusion_conf(self, T=None):
215
+ # can use T < self.T for evaluation
216
+ # follows the guided-diffusion repo conventions
217
+ # t's are evenly spaced
218
+ if self.latent_gen_type == GenerativeType.ddpm:
219
+ section_counts = [T]
220
+ elif self.latent_gen_type == GenerativeType.ddim:
221
+ section_counts = f"ddim{T}"
222
+ else:
223
+ raise NotImplementedError()
224
+
225
+ return SpacedDiffusionBeatGansConfig(
226
+ train_pred_xstart_detach=self.train_pred_xstart_detach,
227
+ gen_type=self.latent_gen_type,
228
+ # latent's model is always ddpm
229
+ model_type=ModelType.ddpm,
230
+ # latent shares the beta scheduler and full T
231
+ betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
232
+ model_mean_type=self.latent_model_mean_type,
233
+ model_var_type=self.latent_model_var_type,
234
+ loss_type=self.latent_loss_type,
235
+ rescale_timesteps=self.latent_rescale_timesteps,
236
+ use_timesteps=space_timesteps(
237
+ num_timesteps=self.T, section_counts=section_counts
238
+ ),
239
+ fp16=self.fp16,
240
+ )
241
+
242
+ @property
243
+ def model_out_channels(self):
244
+ return 3
245
+
246
+ def make_T_sampler(self):
247
+ if self.T_sampler == "uniform":
248
+ return UniformSampler(self.T)
249
+ else:
250
+ raise NotImplementedError()
251
+
252
+ def make_diffusion_conf(self):
253
+ return self._make_diffusion_conf(self.T)
254
+
255
+ def make_eval_diffusion_conf(self):
256
+ return self._make_diffusion_conf(T=self.T_eval)
257
+
258
+ def make_latent_diffusion_conf(self):
259
+ return self._make_latent_diffusion_conf(T=self.T)
260
+
261
+ def make_latent_eval_diffusion_conf(self):
262
+ # latent can have different eval T
263
+ return self._make_latent_diffusion_conf(T=self.latent_T_eval)
264
+
265
+ def make_dataset(self, path=None, **kwargs):
266
+ return LatentDataLoader(
267
+ self.window_size,
268
+ self.frame_jpgs,
269
+ self.lmd_feats_prefix,
270
+ self.audio_prefix,
271
+ self.raw_audio_prefix,
272
+ self.motion_latents_prefix,
273
+ self.pose_prefix,
274
+ self.db_name,
275
+ audio_hz=self.audio_hz,
276
+ )
277
+
278
+ def make_loader(
279
+ self,
280
+ dataset,
281
+ shuffle: bool,
282
+ num_worker: bool = None,
283
+ drop_last: bool = True,
284
+ batch_size: int = None,
285
+ parallel: bool = False,
286
+ ):
287
+ if parallel and distributed.is_initialized():
288
+ # drop last to make sure that there is no added special indexes
289
+ sampler = DistributedSampler(dataset, shuffle=shuffle, drop_last=True)
290
+ else:
291
+ sampler = None
292
+ return DataLoader(
293
+ dataset,
294
+ batch_size=batch_size or self.batch_size,
295
+ sampler=sampler,
296
+ # with sampler, use the sample instead of this option
297
+ shuffle=False if sampler else shuffle,
298
+ num_workers=num_worker or self.num_workers,
299
+ pin_memory=True,
300
+ drop_last=drop_last,
301
+ multiprocessing_context=get_context("fork"),
302
+ )
303
+
304
+ def make_model_conf(self):
305
+ if self.model_name == ModelName.beatgans_ddpm:
306
+ self.model_type = ModelType.ddpm
307
+ self.model_conf = BeatGANsUNetConfig(
308
+ attention_resolutions=self.net_attn,
309
+ channel_mult=self.net_ch_mult,
310
+ conv_resample=True,
311
+ dims=2,
312
+ dropout=self.dropout,
313
+ embed_channels=self.net_beatgans_embed_channels,
314
+ image_size=self.img_size,
315
+ in_channels=3,
316
+ model_channels=self.net_ch,
317
+ num_classes=None,
318
+ num_head_channels=-1,
319
+ num_heads_upsample=-1,
320
+ num_heads=self.net_beatgans_attn_head,
321
+ num_res_blocks=self.net_num_res_blocks,
322
+ num_input_res_blocks=self.net_num_input_res_blocks,
323
+ out_channels=self.model_out_channels,
324
+ resblock_updown=self.net_resblock_updown,
325
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
326
+ use_new_attention_order=False,
327
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
328
+ resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
329
+ )
330
+ elif self.model_name in [
331
+ ModelName.beatgans_autoenc,
332
+ ]:
333
+ cls = BeatGANsAutoencConfig
334
+ # supports both autoenc and vaeddpm
335
+ if self.model_name == ModelName.beatgans_autoenc:
336
+ self.model_type = ModelType.autoencoder
337
+ else:
338
+ raise NotImplementedError()
339
+
340
+ if self.net_latent_net_type == LatentNetType.none:
341
+ latent_net_conf = None
342
+ elif self.net_latent_net_type == LatentNetType.skip:
343
+ latent_net_conf = MLPSkipNetConfig(
344
+ num_channels=self.style_ch,
345
+ skip_layers=self.net_latent_skip_layers,
346
+ num_hid_channels=self.net_latent_num_hid_channels,
347
+ num_layers=self.net_latent_layers,
348
+ num_time_emb_channels=self.net_latent_time_emb_channels,
349
+ activation=self.net_latent_activation,
350
+ use_norm=self.net_latent_use_norm,
351
+ condition_bias=self.net_latent_condition_bias,
352
+ dropout=self.net_latent_dropout,
353
+ last_act=self.net_latent_net_last_act,
354
+ num_time_layers=self.net_latent_num_time_layers,
355
+ time_last_act=self.net_latent_time_last_act,
356
+ )
357
+ else:
358
+ raise NotImplementedError()
359
+
360
+ self.model_conf = cls(
361
+ attention_resolutions=self.net_attn,
362
+ channel_mult=self.net_ch_mult,
363
+ conv_resample=True,
364
+ dims=2,
365
+ dropout=self.dropout,
366
+ embed_channels=self.net_beatgans_embed_channels,
367
+ enc_out_channels=self.style_ch,
368
+ enc_pool=self.net_enc_pool,
369
+ enc_num_res_block=self.net_enc_num_res_blocks,
370
+ enc_channel_mult=self.net_enc_channel_mult,
371
+ enc_grad_checkpoint=self.net_enc_grad_checkpoint,
372
+ enc_attn_resolutions=self.net_enc_attn,
373
+ image_size=self.img_size,
374
+ in_channels=3,
375
+ model_channels=self.net_ch,
376
+ num_classes=None,
377
+ num_head_channels=-1,
378
+ num_heads_upsample=-1,
379
+ num_heads=self.net_beatgans_attn_head,
380
+ num_res_blocks=self.net_num_res_blocks,
381
+ num_input_res_blocks=self.net_num_input_res_blocks,
382
+ out_channels=self.model_out_channels,
383
+ resblock_updown=self.net_resblock_updown,
384
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
385
+ use_new_attention_order=False,
386
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
387
+ resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
388
+ latent_net_conf=latent_net_conf,
389
+ resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
390
+ )
391
+ else:
392
+ raise NotImplementedError(self.model_name)
393
+
394
+ return self.model_conf
src/visualizr/config_base.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+
7
+ @dataclass
8
+ class BaseConfig:
9
+ """BaseConfig provides methods to clone itself,
10
+ inherit settings from another config, propagate settings to nested configs,
11
+ and serialize/deserialize configurations to/from JSON.
12
+ """
13
+
14
+ def clone(self):
15
+ """Return a deep copy of this configuration."""
16
+ return deepcopy(self)
17
+
18
+ def inherit(self, another):
19
+ """inherit common keys from a given config"""
20
+ common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
21
+ for k in common_keys:
22
+ setattr(self, k, getattr(another, k))
23
+
24
+ def propagate(self):
25
+ """push down the configuration to all members"""
26
+ for _, v in self.__dict__.items():
27
+ if isinstance(v, BaseConfig):
28
+ v.inherit(self)
29
+ v.propagate()
30
+
31
+ def save(self, save_path: Path):
32
+ """save config to JSON file"""
33
+ if not save_path.exists():
34
+ save_path.mkdir(parents=True, exist_ok=True)
35
+ conf = self.as_dict_jsonable()
36
+ with open(save_path, "w") as f:
37
+ json.dump(conf, f)
38
+
39
+ def load(self, load_path: Path):
40
+ """load json config"""
41
+ if not load_path.exists():
42
+ load_path.mkdir(parents=True, exist_ok=True)
43
+ with open(load_path) as f:
44
+ conf = json.load(f)
45
+ self.from_dict(conf)
46
+
47
+ def from_dict(self, config_dict, strict=False):
48
+ """Populate configuration attributes from a dictionary, optionally
49
+ enforcing strict key checking.
50
+ """
51
+ for k, v in config_dict.items():
52
+ if not hasattr(self, k):
53
+ if strict:
54
+ raise ValueError(f"loading extra '{k}'")
55
+ print(f"loading extra '{k}'")
56
+ continue
57
+ if isinstance(self.__dict__[k], BaseConfig):
58
+ self.__dict__[k].from_dict(v)
59
+ else:
60
+ self.__dict__[k] = v
61
+
62
+ def as_dict_jsonable(self):
63
+ """Convert the configuration to a JSON-serializable dictionary."""
64
+ conf = {}
65
+ for k, v in self.__dict__.items():
66
+ if isinstance(v, BaseConfig):
67
+ conf[k] = v.as_dict_jsonable()
68
+ else:
69
+ if jsonable(v):
70
+ conf[k] = v
71
+ else:
72
+ # ignore not jsonable
73
+ pass
74
+ return conf
75
+
76
+
77
+ def jsonable(x):
78
+ """Check if the object x is JSON serializable."""
79
+ try:
80
+ json.dumps(x)
81
+ return True
82
+ except TypeError:
83
+ return False
src/visualizr/dataset.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from typing import Dict
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import python_speech_features
8
+ import torchvision
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from tqdm import tqdm
12
+
13
+
14
+ class LatentDataLoader(object):
15
+ def __init__(
16
+ self,
17
+ window_size,
18
+ frame_jpgs,
19
+ lmd_feats_prefix,
20
+ audio_prefix,
21
+ raw_audio_prefix,
22
+ motion_latents_prefix,
23
+ pose_prefix,
24
+ db_name,
25
+ video_fps=25,
26
+ audio_hz=50,
27
+ size=256,
28
+ mfcc_mode=False,
29
+ ):
30
+ self.window_size = window_size
31
+ self.lmd_feats_prefix = lmd_feats_prefix
32
+ self.audio_prefix = audio_prefix
33
+ self.pose_prefix = pose_prefix
34
+ self.video_fps = video_fps
35
+ self.audio_hz = audio_hz
36
+ self.db_name = db_name
37
+ self.raw_audio_prefix = raw_audio_prefix
38
+ self.mfcc_mode = mfcc_mode
39
+
40
+ self.transform = torchvision.transforms.Compose(
41
+ [
42
+ transforms.Resize((size, size)),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
45
+ ]
46
+ )
47
+
48
+ self.data = []
49
+ for db_name in ["VoxCeleb2", "HDTF"]:
50
+ db_png_path = os.path.join(frame_jpgs, db_name)
51
+ for clip_name in tqdm(os.listdir(db_png_path)):
52
+ item_dict: Dict = {}
53
+ item_dict["clip_name"] = clip_name
54
+ item_dict["frame_count"] = len(
55
+ list(os.listdir(os.path.join(frame_jpgs, db_name, clip_name)))
56
+ )
57
+ item_dict["hubert_path"] = os.path.join(
58
+ audio_prefix, db_name, clip_name + ".npy"
59
+ )
60
+ item_dict["wav_path"] = os.path.join(
61
+ raw_audio_prefix, db_name, clip_name + ".wav"
62
+ )
63
+
64
+ item_dict["yaw_pitch_roll_path"] = os.path.join(
65
+ pose_prefix,
66
+ db_name,
67
+ "raw_videos_pose_yaw_pitch_roll",
68
+ clip_name + ".npy",
69
+ )
70
+ if not os.path.exists(item_dict["yaw_pitch_roll_path"]):
71
+ print(f"{db_name}'s {clip_name} miss yaw_pitch_roll_path")
72
+ continue
73
+
74
+ item_dict["yaw_pitch_roll"] = np.load(item_dict["yaw_pitch_roll_path"])
75
+ item_dict["yaw_pitch_roll"] = (
76
+ np.clip(item_dict["yaw_pitch_roll"], -90, 90) / 90.0
77
+ )
78
+
79
+ if not os.path.exists(item_dict["wav_path"]):
80
+ print(f"{db_name}'s {clip_name} miss wav_path")
81
+ continue
82
+
83
+ if not os.path.exists(item_dict["hubert_path"]):
84
+ print(f"{db_name}'s {clip_name} miss hubert_path")
85
+ continue
86
+
87
+ if self.mfcc_mode:
88
+ wav, sr = librosa.load(item_dict["wav_path"], sr=16000)
89
+ input_values = python_speech_features.mfcc(
90
+ signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01
91
+ )
92
+ d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
93
+ d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
94
+ input_values = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
95
+ item_dict["hubert_obj"] = input_values
96
+ else:
97
+ item_dict["hubert_obj"] = np.load(
98
+ item_dict["hubert_path"], mmap_mode="r"
99
+ )
100
+ item_dict["lmd_path"] = os.path.join(
101
+ lmd_feats_prefix, db_name, clip_name + ".txt"
102
+ )
103
+ item_dict["lmd_obj_full"] = self.read_landmark_info(
104
+ item_dict["lmd_path"], upper_face=False
105
+ )
106
+
107
+ motion_start_path = os.path.join(
108
+ motion_latents_prefix, db_name, "motions", clip_name + ".npy"
109
+ )
110
+ motion_direction_path = os.path.join(
111
+ motion_latents_prefix, db_name, "directions", clip_name + ".npy"
112
+ )
113
+
114
+ if not os.path.exists(motion_start_path):
115
+ print(f"{db_name}'s {clip_name} miss motion_start_path")
116
+ continue
117
+ if not os.path.exists(motion_direction_path):
118
+ print(f"{db_name}'s {clip_name} miss motion_direction_path")
119
+ continue
120
+
121
+ item_dict["motion_start_obj"] = np.load(motion_start_path)
122
+ item_dict["motion_direction_obj"] = np.load(motion_direction_path)
123
+
124
+ if self.mfcc_mode:
125
+ min_len = min(
126
+ item_dict["lmd_obj_full"].shape[0],
127
+ item_dict["yaw_pitch_roll"].shape[0],
128
+ item_dict["motion_start_obj"].shape[0],
129
+ item_dict["motion_direction_obj"].shape[0],
130
+ int(item_dict["hubert_obj"].shape[0] / 4),
131
+ item_dict["frame_count"],
132
+ )
133
+ item_dict["frame_count"] = min_len
134
+ item_dict["hubert_obj"] = item_dict["hubert_obj"][: min_len * 4, :]
135
+ else:
136
+ min_len = min(
137
+ item_dict["lmd_obj_full"].shape[0],
138
+ item_dict["yaw_pitch_roll"].shape[0],
139
+ item_dict["motion_start_obj"].shape[0],
140
+ item_dict["motion_direction_obj"].shape[0],
141
+ int(item_dict["hubert_obj"].shape[1] / 2),
142
+ item_dict["frame_count"],
143
+ )
144
+
145
+ item_dict["frame_count"] = min_len
146
+ item_dict["hubert_obj"] = item_dict["hubert_obj"][
147
+ :, : min_len * 2, :
148
+ ]
149
+
150
+ if min_len < self.window_size * self.video_fps + 5:
151
+ continue
152
+
153
+ print("Db count:", len(self.data))
154
+
155
+ def get_single_image(self, image_path):
156
+ img_source = Image.open(image_path).convert("RGB")
157
+ img_source = self.transform(img_source)
158
+ return img_source
159
+
160
+ def get_multiple_ranges(self, lists, multi_ranges):
161
+ # Ensure that multi_ranges is a list of tuples
162
+ if not all(isinstance(item, tuple) and len(item) == 2 for item in multi_ranges):
163
+ raise ValueError(
164
+ "multi_ranges must be a list of (start, end) tuples with exactly two elements each"
165
+ )
166
+ extracted_elements = [lists[start:end] for start, end in multi_ranges]
167
+ return [item for sublist in extracted_elements for item in sublist]
168
+
169
+ def read_landmark_info(self, lmd_path, upper_face=True):
170
+ with open(lmd_path, "r") as file:
171
+ lmd_lines = file.readlines()
172
+ lmd_lines.sort()
173
+
174
+ total_lmd_obj = []
175
+ for i, line in enumerate(lmd_lines):
176
+ # Split the coordinates and filter out any empty strings
177
+ coords = [c for c in line.strip().split(" ") if c]
178
+ coords = coords[1:] # do not include the file name in the first row
179
+ lmd_obj = []
180
+ if upper_face:
181
+ # Ensure that the coordinates are parsed as integers
182
+ for coord_pair in self.get_multiple_ranges(
183
+ coords, [(0, 3), (14, 27), (36, 48)]
184
+ ): # 28个
185
+ x, y = coord_pair.split("_")
186
+ lmd_obj.append((int(x) / 512, int(y) / 512))
187
+ else:
188
+ for coord_pair in coords:
189
+ x, y = coord_pair.split("_")
190
+ lmd_obj.append((int(x) / 512, int(y) / 512))
191
+ total_lmd_obj.append(lmd_obj)
192
+
193
+ return np.array(total_lmd_obj, dtype=np.float32)
194
+
195
+ def calculate_face_height(self, landmarks):
196
+ forehead_center = (landmarks[:, 21, :] + landmarks[:, 22, :]) / 2
197
+ chin_bottom = landmarks[:, 8, :]
198
+ distances = np.linalg.norm(forehead_center - chin_bottom, axis=1, keepdims=True)
199
+ return distances
200
+
201
+ def __getitem__(self, index):
202
+ data_item = self.data[index]
203
+ hubert_obj = data_item["hubert_obj"]
204
+ frame_count = data_item["frame_count"]
205
+ lmd_obj_full = data_item["lmd_obj_full"]
206
+ yaw_pitch_roll = data_item["yaw_pitch_roll"]
207
+ motion_start_obj = data_item["motion_start_obj"]
208
+ motion_direction_obj = data_item["motion_direction_obj"]
209
+
210
+ frame_end_index = random.randint(
211
+ self.window_size * self.video_fps + 1, frame_count - 1
212
+ )
213
+ frame_start_index = frame_end_index - self.window_size * self.video_fps
214
+ frame_hint_index = frame_start_index - 1
215
+
216
+ audio_start_index = int(frame_start_index * (self.audio_hz / self.video_fps))
217
+ audio_end_index = int(frame_end_index * (self.audio_hz / self.video_fps))
218
+
219
+ if self.mfcc_mode:
220
+ audio_feats = hubert_obj[audio_start_index:audio_end_index, :]
221
+ else:
222
+ audio_feats = hubert_obj[:, audio_start_index:audio_end_index, :]
223
+
224
+ lmd_obj_full = lmd_obj_full[frame_hint_index:frame_end_index, :]
225
+
226
+ yaw_pitch_roll = yaw_pitch_roll[frame_start_index:frame_end_index, :]
227
+
228
+ motion_start = motion_start_obj[frame_hint_index]
229
+ motion_direction_start = motion_direction_obj[frame_hint_index]
230
+ motion_direction = motion_direction_obj[frame_start_index:frame_end_index, :]
231
+
232
+ return {
233
+ "motion_start": motion_start,
234
+ "motion_direction": motion_direction,
235
+ "audio_feats": audio_feats,
236
+ # '1:' means taking the first frame as the driven frame.
237
+ # '30' is the noise location,
238
+ # '0' means x coordinate
239
+ "face_location": lmd_obj_full[1:, 30, 0],
240
+ "face_scale": self.calculate_face_height(lmd_obj_full[1:, :, :]),
241
+ "yaw_pitch_roll": yaw_pitch_roll,
242
+ "motion_direction_start": motion_direction_start,
243
+ }
244
+
245
+ def __len__(self):
246
+ return len(self.data)
src/visualizr/dataset_util.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ from shutil import copytree
3
+
4
+ from visualizr import logger
5
+ from visualizr.dist_utils import barrier, get_rank
6
+
7
+
8
+ def use_cached_dataset_path(source_path, cache_path):
9
+ if get_rank() == 0:
10
+ if not path.exists(cache_path):
11
+ logger.info(f"copying the data: {source_path} to {cache_path}")
12
+ copytree(source_path, cache_path)
13
+ barrier()
14
+ return cache_path
src/visualizr/diffusion/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from visualizr.diffusion.diffusion import (
4
+ SpacedDiffusionBeatGans,
5
+ SpacedDiffusionBeatGansConfig,
6
+ )
7
+
8
+ Sampler = Union[SpacedDiffusionBeatGans]
9
+ SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
src/visualizr/diffusion/base.py ADDED
@@ -0,0 +1,1136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import NamedTuple, Tuple
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ from torch.cuda.amp import autocast
8
+
9
+ from visualizr.choices import (
10
+ GenerativeType,
11
+ LossType,
12
+ ModelMeanType,
13
+ ModelType,
14
+ ModelVarType,
15
+ )
16
+ from visualizr.config_base import BaseConfig
17
+ from visualizr.model import Model
18
+ from visualizr.model.nn import mean_flat
19
+
20
+
21
+ @dataclass
22
+ class GaussianDiffusionBeatGansConfig(BaseConfig):
23
+ gen_type: GenerativeType
24
+ betas: Tuple[float]
25
+ model_type: ModelType
26
+ model_mean_type: ModelMeanType
27
+ model_var_type: ModelVarType
28
+ loss_type: LossType
29
+ rescale_timesteps: bool
30
+ fp16: bool
31
+ train_pred_xstart_detach: bool = True
32
+
33
+ def make_sampler(self):
34
+ return GaussianDiffusionBeatGans(self)
35
+
36
+
37
+ class GaussianDiffusionBeatGans:
38
+ """
39
+ Utilities for training and sampling diffusion models.
40
+
41
+ :param betas: A 1-D numpy array of betas for each diffusion timestep,
42
+ starting at T and going to 1.
43
+ :param model_mean_type: A ModelMeanType determining what the model outputs.
44
+ :param model_var_type: A ModelVarType determining how variance is output.
45
+ :param loss_type: A LossType determining the loss function to use.
46
+ :param rescale_timesteps: If True, pass floating point timesteps into the
47
+ model so that they are always scaled like in the
48
+ original paper (0 to 1000).
49
+ """
50
+
51
+ def __init__(self, conf: GaussianDiffusionBeatGansConfig):
52
+ self.conf = conf
53
+ self.model_mean_type = conf.model_mean_type
54
+ self.model_var_type = conf.model_var_type
55
+ self.loss_type = conf.loss_type
56
+ self.rescale_timesteps = conf.rescale_timesteps
57
+
58
+ # Use float64 for accuracy.
59
+ betas = np.array(conf.betas, dtype=np.float64)
60
+ self.betas = betas
61
+ assert len(betas.shape) == 1, "betas must be 1-D"
62
+ assert (betas > 0).all() and (betas <= 1).all()
63
+
64
+ self.num_timesteps = int(betas.shape[0])
65
+
66
+ alphas = 1.0 - betas
67
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
68
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
69
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
70
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
71
+
72
+ # calculations for diffusion q(x_t | x_{t-1}) and others
73
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
74
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
75
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
76
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
77
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
78
+
79
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
80
+ self.posterior_variance = (
81
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
82
+ )
83
+ # log calculation clipped because the posterior variance is 0 at the
84
+ # beginning of the diffusion chain.
85
+ self.posterior_log_variance_clipped = np.log(
86
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
87
+ )
88
+ self.posterior_mean_coef1 = (
89
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
90
+ )
91
+ self.posterior_mean_coef2 = (
92
+ (1.0 - self.alphas_cumprod_prev)
93
+ * np.sqrt(alphas)
94
+ / (1.0 - self.alphas_cumprod)
95
+ )
96
+
97
+ def training_losses(
98
+ self,
99
+ model,
100
+ motion_direction_start: th.Tensor,
101
+ motion_target: th.Tensor,
102
+ motion_start: th.Tensor,
103
+ audio_feats: th.Tensor,
104
+ face_location: th.Tensor,
105
+ face_scale: th.Tensor,
106
+ yaw_pitch_roll: th.Tensor,
107
+ t: th.Tensor,
108
+ model_kwargs=None,
109
+ noise: th.Tensor = None,
110
+ ):
111
+ """
112
+ Compute training losses for a single timestep.
113
+
114
+ :param model: the model to evaluate loss on.
115
+ :param x_start: the [N x C x ...] tensor of inputs.
116
+ :param t: a batch of timestep indices.
117
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
118
+ pass to the model. This can be used for conditioning.
119
+ :param noise: if specified, the specific Gaussian noise to try to remove.
120
+ :return: a dict with the key "loss" containing a tensor of shape [N].
121
+ Some mean or variance settings may also have other keys.
122
+ """
123
+ if model_kwargs is None:
124
+ model_kwargs = {}
125
+ if noise is None:
126
+ noise = th.randn_like(motion_target)
127
+
128
+ x_t = self.q_sample(motion_target, t, noise=noise)
129
+
130
+ terms = {"x_t": x_t}
131
+
132
+ if self.loss_type in [
133
+ LossType.mse,
134
+ LossType.l1,
135
+ ]:
136
+ with autocast(self.conf.fp16):
137
+ # x_t is static wrt. to the diffusion process
138
+ (
139
+ predicted_direction,
140
+ predicted_location,
141
+ predicted_scale,
142
+ predicted_pose,
143
+ ) = model.forward(
144
+ motion_start,
145
+ motion_direction_start,
146
+ audio_feats,
147
+ face_location,
148
+ face_scale,
149
+ yaw_pitch_roll,
150
+ x_t.detach(),
151
+ self._scale_timesteps(t),
152
+ control_flag=False,
153
+ )
154
+
155
+ target_types = {
156
+ ModelMeanType.eps: noise,
157
+ }
158
+ target = target_types[self.model_mean_type]
159
+ assert predicted_direction.shape == target.shape == motion_target.shape
160
+
161
+ if self.loss_type == LossType.mse:
162
+ if self.model_mean_type == ModelMeanType.eps:
163
+ direction_loss = mean_flat((target - predicted_direction) ** 2)
164
+ # import pdb;pdb.set_trace()
165
+ location_loss = mean_flat(
166
+ (face_location.unsqueeze(-1) - predicted_location) ** 2
167
+ )
168
+ scale_loss = mean_flat((face_scale - predicted_scale) ** 2)
169
+ pose_loss = mean_flat((yaw_pitch_roll - predicted_pose) ** 2)
170
+
171
+ terms["mse"] = (
172
+ direction_loss + location_loss + scale_loss + pose_loss
173
+ )
174
+
175
+ else:
176
+ raise NotImplementedError()
177
+ elif self.loss_type == LossType.l1:
178
+ # (n, c, h, w) => (n, )
179
+ terms["mse"] = mean_flat((target - predicted_direction).abs())
180
+ else:
181
+ raise NotImplementedError()
182
+
183
+ if "vb" in terms:
184
+ # if learning the variance also use the vlb loss
185
+ terms["loss"] = terms["mse"] + terms["vb"]
186
+ else:
187
+ terms["loss"] = terms["mse"]
188
+ else:
189
+ raise NotImplementedError(self.loss_type)
190
+
191
+ return terms
192
+
193
+ def sample(
194
+ self,
195
+ model: Model,
196
+ shape=None,
197
+ noise=None,
198
+ cond=None,
199
+ x_start=None,
200
+ clip_denoised=True,
201
+ model_kwargs=None,
202
+ progress=False,
203
+ ):
204
+ """
205
+ Args:
206
+ x_start: given for the autoencoder
207
+ """
208
+ if model_kwargs is None:
209
+ model_kwargs = {}
210
+ if self.conf.model_type.has_autoenc():
211
+ model_kwargs["x_start"] = x_start
212
+ model_kwargs["cond"] = cond
213
+
214
+ if self.conf.gen_type == GenerativeType.ddpm:
215
+ return self.p_sample_loop(
216
+ model,
217
+ shape=shape,
218
+ noise=noise,
219
+ clip_denoised=clip_denoised,
220
+ model_kwargs=model_kwargs,
221
+ progress=progress,
222
+ )
223
+ elif self.conf.gen_type == GenerativeType.ddim:
224
+ return self.ddim_sample_loop(
225
+ model,
226
+ shape=shape,
227
+ noise=noise,
228
+ clip_denoised=clip_denoised,
229
+ model_kwargs=model_kwargs,
230
+ progress=progress,
231
+ )
232
+ else:
233
+ raise NotImplementedError()
234
+
235
+ def q_mean_variance(self, x_start, t):
236
+ """
237
+ Get the distribution q(x_t | x_0).
238
+
239
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
240
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
241
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
242
+ """
243
+ mean = (
244
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
245
+ )
246
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
247
+ log_variance = _extract_into_tensor(
248
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
249
+ )
250
+ return mean, variance, log_variance
251
+
252
+ def q_sample(self, x_start, t, noise=None):
253
+ """
254
+ Diffuse the data for a given number of diffusion steps.
255
+
256
+ In other words, sample from q(x_t | x_0).
257
+
258
+ :param x_start: the initial data batch.
259
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
260
+ :param noise: if specified, the split-out normal noise.
261
+ :return: A noisy version of x_start.
262
+ """
263
+ if noise is None:
264
+ noise = th.randn_like(x_start)
265
+ assert noise.shape == x_start.shape
266
+ return (
267
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
268
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
269
+ * noise
270
+ )
271
+
272
+ def q_posterior_mean_variance(self, x_start, x_t, t):
273
+ """
274
+ Compute the mean and variance of the diffusion posterior:
275
+
276
+ q(x_{t-1} | x_t, x_0)
277
+
278
+ """
279
+ assert x_start.shape == x_t.shape
280
+ posterior_mean = (
281
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
282
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
283
+ )
284
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
285
+ posterior_log_variance_clipped = _extract_into_tensor(
286
+ self.posterior_log_variance_clipped, t, x_t.shape
287
+ )
288
+ assert (
289
+ posterior_mean.shape[0]
290
+ == posterior_variance.shape[0]
291
+ == posterior_log_variance_clipped.shape[0]
292
+ == x_start.shape[0]
293
+ )
294
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
295
+
296
+ def p_mean_variance(
297
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
298
+ ):
299
+ """
300
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
301
+ the initial x, x_0.
302
+
303
+ :param model: the model, which takes a signal and a batch of timesteps
304
+ as input.
305
+ :param x: the [N x C x ...] tensor at time t.
306
+ :param t: a 1-D Tensor of timesteps.
307
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
308
+ :param denoised_fn: if not None, a function which applies to the
309
+ x_start prediction before it is used to sample. Applies before
310
+ clip_denoised.
311
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
312
+ pass to the model. This can be used for conditioning.
313
+ :return: a dict with the following keys:
314
+ - 'mean': the model mean output.
315
+ - 'variance': the model variance output.
316
+ - 'log_variance': the log of 'variance'.
317
+ - 'pred_xstart': the prediction for x_0.
318
+ """
319
+ model_variance, model_log_variance = None, None  # set below for the supported variance types
320
+ if model_kwargs is None:
321
+ model_kwargs = {}
322
+
323
+ motion_start = model_kwargs["start"]
324
+ audio_feats = model_kwargs["audio_driven"]
325
+ face_location = model_kwargs["face_location"]
326
+ face_scale = model_kwargs["face_scale"]
327
+ yaw_pitch_roll = model_kwargs["yaw_pitch_roll"]
328
+ motion_direction_start = model_kwargs["motion_direction_start"]
329
+ control_flag = model_kwargs["control_flag"]
330
+
331
+ B, C = x.shape[:2]
332
+ assert t.shape == (B,)
333
+ with autocast(self.conf.fp16):
334
+ model_forward, _, _, _ = model.forward(
335
+ motion_start,
336
+ motion_direction_start,
337
+ audio_feats,
338
+ face_location,
339
+ face_scale,
340
+ yaw_pitch_roll,
341
+ x,
342
+ self._scale_timesteps(t),
343
+ control_flag,
344
+ )
345
+ model_output = model_forward
346
+
347
+ if self.model_var_type in [ModelVarType.fixed_large, ModelVarType.fixed_small]:
348
+ model_variance, model_log_variance = {
349
+ # for fixedlarge, we set the initial (log-)variance like so
350
+ # to get a better decoder log likelihood.
351
+ ModelVarType.fixed_large: (
352
+ np.append(self.posterior_variance[1], self.betas[1:]),
353
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
354
+ ),
355
+ ModelVarType.fixed_small: (
356
+ self.posterior_variance,
357
+ self.posterior_log_variance_clipped,
358
+ ),
359
+ }[self.model_var_type]
360
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
361
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
362
+
363
+ def process_xstart(x):
364
+ if denoised_fn is not None:
365
+ x = denoised_fn(x)
366
+ if clip_denoised:
367
+ return x.clamp(-1, 1)
368
+ return x
369
+
370
+ if self.model_mean_type in [
371
+ ModelMeanType.eps,
372
+ ]:
373
+ if self.model_mean_type == ModelMeanType.eps:
374
+ pred_xstart = process_xstart(
375
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
376
+ )
377
+ else:
378
+ raise NotImplementedError()
379
+ model_mean, _, _ = self.q_posterior_mean_variance(
380
+ x_start=pred_xstart, x_t=x, t=t
381
+ )
382
+ else:
383
+ raise NotImplementedError(self.model_mean_type)
384
+
385
+ assert (
386
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
387
+ )
388
+ return {
389
+ "mean": model_mean,
390
+ "variance": model_variance,
391
+ "log_variance": model_log_variance,
392
+ "pred_xstart": pred_xstart,
393
+ "model_forward": model_forward,
394
+ }
395
+
396
+ def _predict_xstart_from_eps(self, x_t, t, eps):
397
+ assert x_t.shape == eps.shape
398
+ return (
399
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
400
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
401
+ )
402
+
403
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
404
+ assert x_t.shape == xprev.shape
405
+ return ( # (xprev - coef2*x_t) / coef1
406
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
407
+ - _extract_into_tensor(
408
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
409
+ )
410
+ * x_t
411
+ )
412
+
413
+ def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart):
414
+ return scaled_xstart * _extract_into_tensor(
415
+ self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape
416
+ )
417
+
418
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
419
+ return (
420
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
421
+ - pred_xstart
422
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
423
+
424
+ def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart):
425
+ """
426
+ Args:
427
+ scaled_xstart: is supposed to be sqrt(alphacum) * x_0
428
+ """
429
+ # 1 / sqrt(1-alphabar) * (x_t - scaled xstart)
430
+ return (x_t - scaled_xstart) / _extract_into_tensor(
431
+ self.sqrt_one_minus_alphas_cumprod, t, x_t.shape
432
+ )
433
+
434
+ def _scale_timesteps(self, t):
435
+ if self.rescale_timesteps:
436
+ # scale t to be maxed out at 1000 steps
437
+ return t.float() * (1000.0 / self.num_timesteps)
438
+ return t
439
+
440
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
441
+ """
442
+ Compute the mean for the previous step, given a function cond_fn that
443
+ computes the gradient of a conditional log probability with respect to
444
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
445
+ condition on y.
446
+
447
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
448
+ """
449
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
450
+ new_mean = (
451
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
452
+ )
453
+ return new_mean
454
+
455
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
456
+ """
457
+ Compute what the p_mean_variance output would have been, should the
458
+ model's score function be conditioned by cond_fn.
459
+
460
+ See condition_mean() for details on cond_fn.
461
+
462
+ Unlike condition_mean(), this instead uses the conditioning strategy
463
+ from Song et al (2020).
464
+ """
465
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
466
+
467
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
468
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
469
+ x, self._scale_timesteps(t), **model_kwargs
470
+ )
471
+
472
+ out = p_mean_var.copy()
473
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
474
+ out["mean"], _, _ = self.q_posterior_mean_variance(
475
+ x_start=out["pred_xstart"], x_t=x, t=t
476
+ )
477
+ return out
478
+
479
+ def p_sample(
480
+ self,
481
+ model: Model,
482
+ x,
483
+ t,
484
+ clip_denoised=True,
485
+ denoised_fn=None,
486
+ cond_fn=None,
487
+ model_kwargs=None,
488
+ ):
489
+ """
490
+ Sample x_{t-1} from the model at the given timestep.
491
+
492
+ :param model: the model to sample from.
493
+ :param x: the current tensor at x_{t-1}.
494
+ :param t: the value of t, starting at 0 for the first diffusion step.
495
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
496
+ :param denoised_fn: if not None, a function which applies to the
497
+ x_start prediction before it is used to sample.
498
+ :param cond_fn: if not None, this is a gradient function that acts
499
+ similarly to the model.
500
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
501
+ pass to the model. This can be used for conditioning.
502
+ :return: a dict containing the following keys:
503
+ - 'sample': a random sample from the model.
504
+ - 'pred_xstart': a prediction of x_0.
505
+ """
506
+ out = self.p_mean_variance(
507
+ model,
508
+ x,
509
+ t,
510
+ clip_denoised=clip_denoised,
511
+ denoised_fn=denoised_fn,
512
+ model_kwargs=model_kwargs,
513
+ )
514
+ noise = th.randn_like(x)
515
+ nonzero_mask = (
516
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
517
+ ) # no noise when t == 0
518
+ if cond_fn is not None:
519
+ out["mean"] = self.condition_mean(
520
+ cond_fn, out, x, t, model_kwargs=model_kwargs
521
+ )
522
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
523
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
524
+
525
+ def p_sample_loop(
526
+ self,
527
+ model: Model,
528
+ shape=None,
529
+ noise=None,
530
+ clip_denoised=True,
531
+ denoised_fn=None,
532
+ cond_fn=None,
533
+ model_kwargs=None,
534
+ device=None,
535
+ progress=False,
536
+ ):
537
+ """
538
+ Generate samples from the model.
539
+
540
+ :param model: the model module.
541
+ :param shape: the shape of the samples, (N, C, H, W).
542
+ :param noise: if specified, the noise from the encoder to sample.
543
+ Should be of the same shape as `shape`.
544
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
545
+ :param denoised_fn: if not None, a function which applies to the
546
+ x_start prediction before it is used to sample.
547
+ :param cond_fn: if not None, this is a gradient function that acts
548
+ similarly to the model.
549
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
550
+ pass to the model. This can be used for conditioning.
551
+ :param device: if specified, the device to create the samples on.
552
+ If not specified, use a model parameter's device.
553
+ :param progress: if True, show a tqdm progress bar.
554
+ :return: a non-differentiable batch of samples.
555
+ """
556
+ final = None
557
+ for sample in self.p_sample_loop_progressive(
558
+ model,
559
+ shape,
560
+ noise=noise,
561
+ clip_denoised=clip_denoised,
562
+ denoised_fn=denoised_fn,
563
+ cond_fn=cond_fn,
564
+ model_kwargs=model_kwargs,
565
+ device=device,
566
+ progress=progress,
567
+ ):
568
+ final = sample
569
+ return final["sample"]
570
+
571
+ def p_sample_loop_progressive(
572
+ self,
573
+ model: Model,
574
+ shape=None,
575
+ noise=None,
576
+ clip_denoised=True,
577
+ denoised_fn=None,
578
+ cond_fn=None,
579
+ model_kwargs=None,
580
+ device=None,
581
+ progress=False,
582
+ ):
583
+ """
584
+ Generate samples from the model and yield intermediate samples from
585
+ each timestep of diffusion.
586
+
587
+ Arguments are the same as p_sample_loop().
588
+ Returns a generator over dicts, where each dict is the return value of
589
+ p_sample().
590
+ """
591
+ if device is None:
592
+ device = next(model.parameters()).device
593
+ if noise is not None:
594
+ img = noise
595
+ else:
596
+ assert isinstance(shape, (tuple, list))
597
+ img = th.randn(*shape, device=device)
598
+ indices = list(range(self.num_timesteps))[::-1]
599
+
600
+ if progress:
601
+ # Lazy import so that we don't depend on tqdm.
602
+ from tqdm.auto import tqdm
603
+
604
+ indices = tqdm(indices)
605
+
606
+ for i in indices:
607
+ # t = th.tensor([i] * shape[0], device=device)
608
+ t = th.tensor([i] * len(img), device=device)
609
+ with th.no_grad():
610
+ out = self.p_sample(
611
+ model,
612
+ img,
613
+ t,
614
+ clip_denoised=clip_denoised,
615
+ denoised_fn=denoised_fn,
616
+ cond_fn=cond_fn,
617
+ model_kwargs=model_kwargs,
618
+ )
619
+ yield out
620
+ img = out["sample"]
621
+
622
+ def ddim_sample(
623
+ self,
624
+ model: Model,
625
+ x,
626
+ t,
627
+ clip_denoised=True,
628
+ denoised_fn=None,
629
+ cond_fn=None,
630
+ model_kwargs=None,
631
+ eta=0.0,
632
+ ):
633
+ """
634
+ Sample x_{t-1} from the model using DDIM.
635
+
636
+ Same usage as p_sample().
637
+ """
638
+ out = self.p_mean_variance(
639
+ model,
640
+ x,
641
+ t,
642
+ clip_denoised=clip_denoised,
643
+ denoised_fn=denoised_fn,
644
+ model_kwargs=model_kwargs,
645
+ )
646
+ if cond_fn is not None:
647
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
648
+
649
+ # Usually our model outputs epsilon, but we re-derive it
650
+ # in case we used x_start or x_prev prediction.
651
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
652
+
653
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
654
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
655
+ sigma = (
656
+ eta
657
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
658
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
659
+ )
660
+ # Equation 12.
661
+ noise = th.randn_like(x)
662
+ mean_pred = (
663
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
664
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
665
+ )
666
+ nonzero_mask = (
667
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
668
+ ) # no noise when t == 0
669
+ sample = mean_pred + nonzero_mask * sigma * noise
670
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
671
+
672
+ def ddim_reverse_sample(
673
+ self,
674
+ model: Model,
675
+ x,
676
+ t,
677
+ clip_denoised=True,
678
+ denoised_fn=None,
679
+ model_kwargs=None,
680
+ eta=0.0,
681
+ ):
682
+ """
683
+ Sample x_{t+1} from the model using DDIM reverse ODE.
684
+ NOTE: appears to be unused in this codebase.
685
+ """
686
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
687
+ out = self.p_mean_variance(
688
+ model,
689
+ x,
690
+ t,
691
+ clip_denoised=clip_denoised,
692
+ denoised_fn=denoised_fn,
693
+ model_kwargs=model_kwargs,
694
+ )
695
+ # Usually our model outputs epsilon, but we re-derive it
696
+ # in case we used x_start or x_prev prediction.
697
+ eps = (
698
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
699
+ - out["pred_xstart"]
700
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
701
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
702
+
703
+ # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt)
704
+ mean_pred = (
705
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
706
+ + th.sqrt(1 - alpha_bar_next) * eps
707
+ )
708
+
709
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
710
+
711
+ def ddim_reverse_sample_loop(
712
+ self,
713
+ model: Model,
714
+ x,
715
+ clip_denoised=True,
716
+ denoised_fn=None,
717
+ model_kwargs=None,
718
+ eta=0.0,
719
+ device=None,
720
+ ):
721
+ if device is None:
722
+ device = next(model.parameters()).device
723
+ sample_t = []
724
+ xstart_t = []
725
+ T = []
726
+ indices = list(range(self.num_timesteps))
727
+ sample = x
728
+ for i in indices:
729
+ t = th.tensor([i] * len(sample), device=device)
730
+ with th.no_grad():
731
+ out = self.ddim_reverse_sample(
732
+ model,
733
+ sample,
734
+ t=t,
735
+ clip_denoised=clip_denoised,
736
+ denoised_fn=denoised_fn,
737
+ model_kwargs=model_kwargs,
738
+ eta=eta,
739
+ )
740
+ sample = out["sample"]
741
+ # [1, ..., T]
742
+ sample_t.append(sample)
743
+ # [0, ...., T-1]
744
+ xstart_t.append(out["pred_xstart"])
745
+ # [0, ..., T-1] ready to use
746
+ T.append(t)
747
+
748
+ return {
749
+ # x_T
750
+ "sample": sample,
751
+ # (1, ..., T)
752
+ "sample_t": sample_t,
753
+ # xstart here is a bit different from sampling from T = T-1 to T = 0
754
+ # may not be exact
755
+ "xstart_t": xstart_t,
756
+ "T": T,
757
+ }
758
+
759
+ def ddim_sample_loop(
760
+ self,
761
+ model: Model,
762
+ shape=None,
763
+ noise=None,
764
+ clip_denoised=True,
765
+ denoised_fn=None,
766
+ cond_fn=None,
767
+ model_kwargs=None,
768
+ device=None,
769
+ progress=False,
770
+ eta=0.0,
771
+ ):
772
+ """
773
+ Generate samples from the model using DDIM.
774
+
775
+ Same usage as p_sample_loop().
776
+ """
777
+ final = None
778
+ for sample in self.ddim_sample_loop_progressive(
779
+ model,
780
+ shape,
781
+ noise=noise,
782
+ clip_denoised=clip_denoised,
783
+ denoised_fn=denoised_fn,
784
+ cond_fn=cond_fn,
785
+ model_kwargs=model_kwargs,
786
+ device=device,
787
+ progress=progress,
788
+ eta=eta,
789
+ ):
790
+ final = sample
791
+ return final["sample"]
792
+
793
+ def ddim_sample_loop_progressive(
794
+ self,
795
+ model: Model,
796
+ shape=None,
797
+ noise=None,
798
+ clip_denoised=True,
799
+ denoised_fn=None,
800
+ cond_fn=None,
801
+ model_kwargs=None,
802
+ device=None,
803
+ progress=False,
804
+ eta=0.0,
805
+ ):
806
+ """
807
+ Use DDIM to sample from the model and yield intermediate samples from
808
+ each timestep of DDIM.
809
+
810
+ Same usage as p_sample_loop_progressive().
811
+ """
812
+ if device is None:
813
+ device = next(model.parameters()).device
814
+ if noise is not None:
815
+ img = noise
816
+ else:
817
+ assert isinstance(shape, (tuple, list))
818
+ img = th.randn(*shape, device=device)
819
+ indices = list(range(self.num_timesteps))[::-1]
820
+
821
+ if progress:
822
+ # Lazy import so that we don't depend on tqdm.
823
+ from tqdm.auto import tqdm
824
+
825
+ indices = tqdm(indices)
826
+
827
+ for i in indices:
828
+ if isinstance(model_kwargs, list):
829
+ # index dependent model kwargs
830
+ # (T-1, ..., 0)
831
+ _kwargs = model_kwargs[i]
832
+ else:
833
+ _kwargs = model_kwargs
834
+
835
+ t = th.tensor([i] * len(img), device=device)
836
+ with th.no_grad():
837
+ out = self.ddim_sample(
838
+ model,
839
+ img,
840
+ t,
841
+ clip_denoised=clip_denoised,
842
+ denoised_fn=denoised_fn,
843
+ cond_fn=cond_fn,
844
+ model_kwargs=_kwargs,
845
+ eta=eta,
846
+ )
847
+ out["t"] = t
848
+ yield out
849
+ img = out["sample"]
850
+
851
+ def _vb_terms_bpd(
852
+ self, model: Model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
853
+ ):
854
+ """
855
+ Get a term for the variational lower-bound.
856
+
857
+ The resulting units are bits (rather than nats, as one might expect).
858
+ This allows for comparison to other papers.
859
+
860
+ :return: a dict with the following keys:
861
+ - 'output': a shape [N] tensor of NLLs or KLs.
862
+ - 'pred_xstart': the x_0 predictions.
863
+ """
864
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
865
+ x_start=x_start, x_t=x_t, t=t
866
+ )
867
+ out = self.p_mean_variance(
868
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
869
+ )
870
+ kl = normal_kl(
871
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
872
+ )
873
+ kl = mean_flat(kl) / np.log(2.0)
874
+
875
+ decoder_nll = -discretized_gaussian_log_likelihood(
876
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
877
+ )
878
+ assert decoder_nll.shape == x_start.shape
879
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
880
+
881
+ # At the first timestep return the decoder NLL,
882
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
883
+ output = th.where((t == 0), decoder_nll, kl)
884
+ return {
885
+ "output": output,
886
+ "pred_xstart": out["pred_xstart"],
887
+ "model_forward": out["model_forward"],
888
+ }
889
+
890
+ def _prior_bpd(self, x_start):
891
+ """
892
+ Get the prior KL term for the variational lower-bound, measured in
893
+ bits-per-dim.
894
+
895
+ This term can't be optimized, as it only depends on the encoder.
896
+
897
+ :param x_start: the [N x C x ...] tensor of inputs.
898
+ :return: a batch of [N] KL values (in bits), one per batch element.
899
+ """
900
+ batch_size = x_start.shape[0]
901
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
902
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
903
+ kl_prior = normal_kl(
904
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
905
+ )
906
+ return mean_flat(kl_prior) / np.log(2.0)
907
+
908
+ def calc_bpd_loop(
909
+ self, model: Model, x_start, clip_denoised=True, model_kwargs=None
910
+ ):
911
+ """
912
+ Compute the entire variational lower-bound, measured in bits-per-dim,
913
+ as well as other related quantities.
914
+
915
+ :param model: the model to evaluate loss on.
916
+ :param x_start: the [N x C x ...] tensor of inputs.
917
+ :param clip_denoised: if True, clip denoised samples.
918
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
919
+ pass to the model. This can be used for conditioning.
920
+
921
+ :return: a dict containing the following keys:
922
+ - total_bpd: the total variational lower-bound, per batch element.
923
+ - prior_bpd: the prior term in the lower-bound.
924
+ - vb: an [N x T] tensor of terms in the lower-bound.
925
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
926
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
927
+ """
928
+ device = x_start.device
929
+ batch_size = x_start.shape[0]
930
+
931
+ vb = []
932
+ xstart_mse = []
933
+ mse = []
934
+ for t in list(range(self.num_timesteps))[::-1]:
935
+ t_batch = th.tensor([t] * batch_size, device=device)
936
+ noise = th.randn_like(x_start)
937
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
938
+ # Calculate VLB term at the current timestep
939
+ with th.no_grad():
940
+ out = self._vb_terms_bpd(
941
+ model,
942
+ x_start=x_start,
943
+ x_t=x_t,
944
+ t=t_batch,
945
+ clip_denoised=clip_denoised,
946
+ model_kwargs=model_kwargs,
947
+ )
948
+ vb.append(out["output"])
949
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
950
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
951
+ mse.append(mean_flat((eps - noise) ** 2))
952
+
953
+ vb = th.stack(vb, dim=1)
954
+ xstart_mse = th.stack(xstart_mse, dim=1)
955
+ mse = th.stack(mse, dim=1)
956
+
957
+ prior_bpd = self._prior_bpd(x_start)
958
+ total_bpd = vb.sum(dim=1) + prior_bpd
959
+ return {
960
+ "total_bpd": total_bpd,
961
+ "prior_bpd": prior_bpd,
962
+ "vb": vb,
963
+ "xstart_mse": xstart_mse,
964
+ "mse": mse,
965
+ }
966
+
967
+
968
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
969
+ """
970
+ Extract values from a 1-D numpy array for a batch of indices.
971
+
972
+ :param arr: the 1-D numpy array.
973
+ :param timesteps: a tensor of indices into the array to extract.
974
+ :param broadcast_shape: a larger shape of K dimensions with the batch
975
+ dimension equal to the length of timesteps.
976
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
977
+ """
978
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
979
+ while len(res.shape) < len(broadcast_shape):
980
+ res = res[..., None]
981
+ return res.expand(broadcast_shape)
982
+
983
+
984
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
985
+ """
986
+ Get a pre-defined beta schedule for the given name.
987
+
988
+ The beta schedule library consists of beta schedules which remain similar
989
+ in the limit of num_diffusion_timesteps.
990
+ Beta schedules may be added, but should not be removed or changed once
991
+ they are committed to maintain backwards compatibility.
992
+ """
993
+ if schedule_name == "linear":
994
+ # Linear schedule from Ho et al, extended to work for any number of
995
+ # diffusion steps.
996
+ scale = 1000 / num_diffusion_timesteps
997
+ beta_start = scale * 0.0001
998
+ beta_end = scale * 0.02
999
+ return np.linspace(
1000
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
1001
+ )
1002
+ elif schedule_name == "cosine":
1003
+ return betas_for_alpha_bar(
1004
+ num_diffusion_timesteps,
1005
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
1006
+ )
1007
+ elif schedule_name == "const0.01":
1008
+ scale = 1000 / num_diffusion_timesteps
1009
+ return np.array([scale * 0.01] * num_diffusion_timesteps, dtype=np.float64)
1010
+ elif schedule_name == "const0.015":
1011
+ scale = 1000 / num_diffusion_timesteps
1012
+ return np.array([scale * 0.015] * num_diffusion_timesteps, dtype=np.float64)
1013
+ elif schedule_name == "const0.008":
1014
+ scale = 1000 / num_diffusion_timesteps
1015
+ return np.array([scale * 0.008] * num_diffusion_timesteps, dtype=np.float64)
1016
+ elif schedule_name == "const0.0065":
1017
+ scale = 1000 / num_diffusion_timesteps
1018
+ return np.array([scale * 0.0065] * num_diffusion_timesteps, dtype=np.float64)
1019
+ elif schedule_name == "const0.0055":
1020
+ scale = 1000 / num_diffusion_timesteps
1021
+ return np.array([scale * 0.0055] * num_diffusion_timesteps, dtype=np.float64)
1022
+ elif schedule_name == "const0.0045":
1023
+ scale = 1000 / num_diffusion_timesteps
1024
+ return np.array([scale * 0.0045] * num_diffusion_timesteps, dtype=np.float64)
1025
+ elif schedule_name == "const0.0035":
1026
+ scale = 1000 / num_diffusion_timesteps
1027
+ return np.array([scale * 0.0035] * num_diffusion_timesteps, dtype=np.float64)
1028
+ elif schedule_name == "const0.0025":
1029
+ scale = 1000 / num_diffusion_timesteps
1030
+ return np.array([scale * 0.0025] * num_diffusion_timesteps, dtype=np.float64)
1031
+ elif schedule_name == "const0.0015":
1032
+ scale = 1000 / num_diffusion_timesteps
1033
+ return np.array([scale * 0.0015] * num_diffusion_timesteps, dtype=np.float64)
1034
+ else:
1035
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1036
+
1037
+
1038
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
1039
+ """
1040
+ Create a beta schedule that discretizes the given alpha_t_bar function,
1041
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
1042
+
1043
+ :param num_diffusion_timesteps: the number of betas to produce.
1044
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1045
+ produces the cumulative product of (1-beta) up to that
1046
+ part of the diffusion process.
1047
+ :param max_beta: the maximum beta to use; use values lower than 1 to
1048
+ prevent singularities.
1049
+ """
1050
+ betas = []
1051
+ for i in range(num_diffusion_timesteps):
1052
+ t1 = i / num_diffusion_timesteps
1053
+ t2 = (i + 1) / num_diffusion_timesteps
1054
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1055
+ return np.array(betas)
1056
+
1057
+
1058
+ def normal_kl(mean1, logvar1, mean2, logvar2):
1059
+ """
1060
+ Compute the KL divergence between two gaussians.
1061
+
1062
+ Shapes are automatically broadcasted, so batches can be compared to
1063
+ scalars, among other use cases.
1064
+ """
1065
+ tensor = None
1066
+ for obj in (mean1, logvar1, mean2, logvar2):
1067
+ if isinstance(obj, th.Tensor):
1068
+ tensor = obj
1069
+ break
1070
+ assert tensor is not None, "at least one argument must be a Tensor"
1071
+
1072
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
1073
+ # Tensors, but it does not work for th.exp().
1074
+ logvar1, logvar2 = [
1075
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
1076
+ for x in (logvar1, logvar2)
1077
+ ]
1078
+
1079
+ return 0.5 * (
1080
+ -1.0
1081
+ + logvar2
1082
+ - logvar1
1083
+ + th.exp(logvar1 - logvar2)
1084
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
1085
+ )
1086
+
1087
+
1088
+ def approx_standard_normal_cdf(x):
1089
+ """
1090
+ A fast approximation of the cumulative distribution function of the
1091
+ standard normal.
1092
+ """
1093
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1094
+
1095
+
1096
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1097
+ """
1098
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
1099
+ given image.
1100
+
1101
+ :param x: the target images. It is assumed that this was uint8 values,
1102
+ rescaled to the range [-1, 1].
1103
+ :param means: the Gaussian mean Tensor.
1104
+ :param log_scales: the Gaussian log stddev Tensor.
1105
+ :return: a tensor like x of log probabilities (in nats).
1106
+ """
1107
+ assert x.shape == means.shape == log_scales.shape
1108
+ centered_x = x - means
1109
+ inv_stdv = th.exp(-log_scales)
1110
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1111
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1112
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1113
+ cdf_min = approx_standard_normal_cdf(min_in)
1114
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1115
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1116
+ cdf_delta = cdf_plus - cdf_min
1117
+ log_probs = th.where(
1118
+ x < -0.999,
1119
+ log_cdf_plus,
1120
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
1121
+ )
1122
+ assert log_probs.shape == x.shape
1123
+ return log_probs
1124
+
1125
+
1126
+ class DummyModel(th.nn.Module):
1127
+ def __init__(self, pred):
1128
+ super().__init__()
1129
+ self.pred = pred
1130
+
1131
+ def forward(self, *args, **kwargs):
1132
+ return DummyReturn(pred=self.pred)
1133
+
1134
+
1135
+ class DummyReturn(NamedTuple):
1136
+ pred: th.Tensor
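
The forward process that q_sample implements above has the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with the per-timestep coefficients broadcast over the batch by _extract_into_tensor. The following is a minimal standalone sketch of that step using the "linear" schedule defined in get_named_beta_schedule; the tensor shapes are illustrative only, not tied to this repository.

import numpy as np
import torch

T = 1000
betas = np.linspace(1e-4, 0.02, T, dtype=np.float64)   # the "linear" schedule above
alphas_cumprod = np.cumprod(1.0 - betas)

def q_sample_sketch(x_start: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Broadcast alpha_bar_t over the batch, as _extract_into_tensor does.
    a_bar = torch.from_numpy(alphas_cumprod).float()[t]
    while a_bar.dim() < x_start.dim():
        a_bar = a_bar[..., None]
    noise = torch.randn_like(x_start)
    return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise

x0 = torch.randn(4, 20, 125)     # illustrative batch of motion sequences
t = torch.randint(0, T, (4,))
x_t = q_sample_sketch(x0, t)     # distributed like q_sample(x0, t) above
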
src/visualizr/diffusion/diffusion.py ADDED
@@ -0,0 +1,183 @@
1
+ from dataclasses import dataclass
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ from torch import tensor
6
+
7
+ from visualizr.diffusion.base import (
8
+ GaussianDiffusionBeatGans,
9
+ GaussianDiffusionBeatGansConfig,
10
+ )
11
+ from visualizr.model import Model
12
+
13
+
14
+ def space_timesteps(num_timesteps, section_counts):
15
+ """
16
+ Create a list of timesteps to use from an original diffusion process,
17
+ given the number of timesteps we want to take from equally-sized portions
18
+ of the original process.
19
+
20
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
21
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
22
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
23
+
24
+ If the stride is a string starting with "ddim", then the fixed striding
25
+ from the DDIM paper is used, and only one section is allowed.
26
+
27
+ :param num_timesteps: the number of diffusion steps in the original
28
+ process to divide up.
29
+ :param section_counts: either a list of numbers, or a string containing
30
+ comma-separated numbers, indicating the step count
31
+ per section. As a special case, use "ddimN" where N
32
+ is a number of steps to use the striding from the
33
+ DDIM paper.
34
+ :return: a set of diffusion steps from the original process to use.
35
+ """
36
+ if isinstance(section_counts, str):
37
+ if section_counts.startswith("ddim"):
38
+ desired_count = int(section_counts[len("ddim") :])
39
+ for i in range(1, num_timesteps):
40
+ if len(range(0, num_timesteps, i)) == desired_count:
41
+ return set(range(0, num_timesteps, i))
42
+ raise ValueError(
43
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
44
+ )
45
+ section_counts = [int(x) for x in section_counts.split(",")]
46
+ size_per = num_timesteps // len(section_counts)
47
+ extra = num_timesteps % len(section_counts)
48
+ start_idx = 0
49
+ all_steps = []
50
+ for i, section_count in enumerate(section_counts):
51
+ size = size_per + (1 if i < extra else 0)
52
+ if size < section_count:
53
+ raise ValueError(
54
+ f"cannot divide section of {size} steps into {section_count}"
55
+ )
56
+ if section_count <= 1:
57
+ frac_stride = 1
58
+ else:
59
+ frac_stride = (size - 1) / (section_count - 1)
60
+ cur_idx = 0.0
61
+ taken_steps = []
62
+ for _ in range(section_count):
63
+ taken_steps.append(start_idx + round(cur_idx))
64
+ cur_idx += frac_stride
65
+ all_steps += taken_steps
66
+ start_idx += size
67
+ return set(all_steps)
68
+
69
+
70
+ @dataclass
71
+ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
72
+ use_timesteps: Tuple[int] = None
73
+
74
+ def make_sampler(self):
75
+ return SpacedDiffusionBeatGans(self)
76
+
77
+
78
+ class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
79
+ """
80
+ A diffusion process which can skip steps in a base diffusion process.
81
+
82
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
83
+ original diffusion process to retain.
84
+ :param kwargs: the kwargs to create the base diffusion process.
85
+ """
86
+
87
+ def __init__(self, conf: SpacedDiffusionBeatGansConfig):
88
+ self.conf = conf
89
+ self.use_timesteps = set(conf.use_timesteps)
90
+ # how the new t's mapped to the old t's
91
+ self.timestep_map = []
92
+ self.original_num_steps = len(conf.betas)
93
+
94
+ base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
95
+ last_alpha_cumprod = 1.0
96
+ new_betas = []
97
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
98
+ if i in self.use_timesteps:
99
+ # getting the new betas of the new timesteps
100
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
101
+ last_alpha_cumprod = alpha_cumprod
102
+ self.timestep_map.append(i)
103
+ conf.betas = np.array(new_betas)
104
+ super().__init__(conf)
105
+
106
+ def p_mean_variance(self, model: Model, *args, **kwargs):
107
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
108
+
109
+ def training_losses(self, model: Model, *args, **kwargs):
110
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
111
+
112
+ def condition_mean(self, cond_fn, *args, **kwargs):
113
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
114
+
115
+ def condition_score(self, cond_fn, *args, **kwargs):
116
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
117
+
118
+ def _wrap_model(self, model: Model):
119
+ if isinstance(model, _WrappedModel):
120
+ return model
121
+ return _WrappedModel(
122
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
123
+ )
124
+
125
+ def _scale_timesteps(self, t):
126
+ # Scaling is done by the wrapped model.
127
+ return t
128
+
129
+
130
+ class _WrappedModel:
131
+ """converting the supplied t's to the old t's scales."""
132
+
133
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
134
+ self.model = model
135
+ self.timestep_map = timestep_map
136
+ self.rescale_timesteps = rescale_timesteps
137
+ self.original_num_steps = original_num_steps
138
+
139
+ def forward(
140
+ self,
141
+ motion_start,
142
+ motion_direction_start,
143
+ audio_feats,
144
+ face_location,
145
+ face_scale,
146
+ yaw_pitch_roll,
147
+ x_t,
148
+ t,
149
+ control_flag=False,
150
+ ):
151
+ """
152
+ Args:
153
+ t: timesteps in the spaced (possibly much shorter) eval schedule;
154
+ they can be << T, so they are mapped back to the original
155
+ training timesteps before being passed to the wrapped model.
156
+ control_flag: forwarded to the wrapped model unchanged.
157
+ """
158
+ map_tensor = tensor(self.timestep_map, device=t.device, dtype=t.dtype)
159
+
160
+ def do(t):
161
+ new_ts = map_tensor[t]
162
+ if self.rescale_timesteps:
163
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
164
+ return new_ts
165
+
166
+ return self.model(
167
+ motion_start,
168
+ motion_direction_start,
169
+ audio_feats,
170
+ face_location,
171
+ face_scale,
172
+ yaw_pitch_roll,
173
+ x_t,
174
+ do(t),
175
+ control_flag=control_flag,
176
+ )
177
+
178
+ def __getattr__(self, name):
179
+ # allow for calling the model's methods
180
+ if hasattr(self.model, name):
181
+ func = getattr(self.model, name)
182
+ return func
183
+ raise AttributeError(name)
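
space_timesteps is what turns the full training schedule into the shorter evaluation schedule that SpacedDiffusionBeatGans then re-indexes (new betas are recomputed so the retained steps keep the same alpha_bar). A small usage sketch, assuming the package is importable as visualizr:

from visualizr.diffusion.diffusion import space_timesteps

# "ddimN" picks an even integer stride through the original schedule.
steps = sorted(space_timesteps(num_timesteps=1000, section_counts="ddim50"))
print(len(steps), steps[:5])      # 50 [0, 20, 40, 60, 80]

# A comma-separated string splits the schedule into equal sections,
# each strided to its own step count.
steps = sorted(space_timesteps(num_timesteps=300, section_counts="10,15,20"))
print(len(steps))                 # 45
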
src/visualizr/diffusion/resample.py ADDED
@@ -0,0 +1,63 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+
6
+
7
+ def create_named_schedule_sampler(name, diffusion):
8
+ """
9
+ Create a ScheduleSampler from a library of pre-defined samplers.
10
+
11
+ :param name: The name of the sampler.
12
+ :param diffusion: The diffusion object to sample for.
13
+ """
14
+ if name == "uniform":
15
+ return UniformSampler(diffusion.num_timesteps)
16
+ else:
17
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
18
+
19
+
20
+ class ScheduleSampler(ABC):
21
+ """
22
+ A distribution over timesteps in the diffusion process, intended to reduce
23
+ variance of the goal.
24
+
25
+ By default, samplers perform unbiased importance sampling, in which the
26
+ objective's mean is unchanged.
27
+ However, subclasses may override sample() to change how the resampled
28
+ terms are reweighted, allowing for actual changes in the goal.
29
+ """
30
+
31
+ @abstractmethod
32
+ def weights(self):
33
+ """
34
+ Get a numpy array of weights, one per diffusion step.
35
+
36
+ The weights needn't be normalized but must be positive.
37
+ """
38
+
39
+ def sample(self, batch_size, device):
40
+ """
41
+ Importance-sample timesteps for a batch.
42
+
43
+ :param batch_size: The number of timesteps.
44
+ :param device: The torch device to save to.
45
+ :return: A tuple (timesteps, weights):
46
+ - timesteps: a tensor of timestep indices.
47
+ - weights: a tensor of weights to scale the resulting losses.
48
+ """
49
+ w = self.weights()
50
+ p = w / np.sum(w)
51
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
52
+ indices = th.from_numpy(indices_np).long().to(device)
53
+ weights_np = 1 / (len(p) * p[indices_np])
54
+ weights = th.from_numpy(weights_np).float().to(device)
55
+ return indices, weights
56
+
57
+
58
+ class UniformSampler(ScheduleSampler):
59
+ def __init__(self, num_timesteps):
60
+ self._weights = np.ones([num_timesteps])
61
+
62
+ def weights(self):
63
+ return self._weights
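
UniformSampler is the simplest schedule sampler: every timestep is equally likely and the returned importance weights are all ones, so the per-sample losses are left unscaled. A short standalone sketch, assuming the package is importable:

import torch

from visualizr.diffusion.resample import UniformSampler

sampler = UniformSampler(num_timesteps=1000)
t, weights = sampler.sample(batch_size=8, device=torch.device("cpu"))
print(t.shape, t.dtype)   # torch.Size([8]) torch.int64, values in [0, 999]
print(weights)            # a tensor of ones: uniform sampling leaves losses unweighted
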
src/visualizr/dist_utils.py ADDED
@@ -0,0 +1,43 @@
1
+ from typing import List
2
+
3
+ from torch import distributed
4
+
5
+
6
+ def barrier():
7
+ if distributed.is_initialized():
8
+ distributed.barrier()
9
+ else:
10
+ pass
11
+
12
+
13
+ def broadcast(data, src):
14
+ if distributed.is_initialized():
15
+ distributed.broadcast(data, src)
16
+ else:
17
+ pass
18
+
19
+
20
+ def all_gather(data: List, src):
21
+ if distributed.is_initialized():
22
+ distributed.all_gather(data, src)
23
+ else:
24
+ data[0] = src
25
+
26
+
27
+ def get_rank():
28
+ if distributed.is_initialized():
29
+ return distributed.get_rank()
30
+ else:
31
+ return 0
32
+
33
+
34
+ def get_world_size():
35
+ if distributed.is_initialized():
36
+ return distributed.get_world_size()
37
+ else:
38
+ return 1
39
+
40
+
41
+ def chunk_size(size, rank, world_size):
42
+ extra = rank < size % world_size
43
+ return size // world_size + extra
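
Each helper here falls back to a sensible single-process default when torch.distributed has not been initialized, which is what lets local debugging run without a distributed launcher. chunk_size spreads any remainder over the lowest ranks, as the sketch below shows (assuming the package is importable):

from visualizr.dist_utils import chunk_size, get_rank, get_world_size

# 10 items over 4 workers: the first two ranks take one extra item each.
print([chunk_size(10, rank, 4) for rank in range(4)])   # [3, 3, 2, 2]

# Outside an initialized process group, the fallbacks describe a single process.
print(get_rank(), get_world_size())                     # 0 1
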
src/visualizr/experiment.py ADDED
@@ -0,0 +1,386 @@
1
+ import copy
2
+ import os
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from pytorch_lightning import loggers as pl_loggers
8
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9
+ from torch.cuda import amp
10
+ from torch.optim.optimizer import Optimizer
11
+ from torch.utils.data.dataset import TensorDataset
12
+
13
+ from visualizr import logger
14
+ from visualizr.choices import OptimizerType, TrainMode
15
+ from visualizr.config import TrainConfig
16
+ from visualizr.dist_utils import get_world_size
17
+ from visualizr.model.seq2seq import DiffusionPredictor
18
+ from visualizr.renderer import render_condition
19
+
20
+
21
+ class LitModel(pl.LightningModule):
22
+ def __init__(self, conf: TrainConfig):
23
+ super().__init__()
24
+ assert conf.train_mode != TrainMode.manipulate
25
+ if conf.seed is not None:
26
+ pl.seed_everything(conf.seed)
27
+
28
+ self.save_hyperparameters(conf.as_dict_jsonable())
29
+
30
+ self.conf = conf
31
+
32
+ self.model = DiffusionPredictor(conf)
33
+
34
+ self.ema_model = copy.deepcopy(self.model)
35
+ self.ema_model.requires_grad_(False)
36
+ self.ema_model.eval()
37
+
38
+ self.sampler = conf.make_diffusion_conf().make_sampler()
39
+ self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
40
+
41
+ # this is shared for both model and latent
42
+ self.T_sampler = conf.make_T_sampler()
43
+
44
+ if conf.train_mode.use_latent_net():
45
+ self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
46
+ self.eval_latent_sampler = (
47
+ conf.make_latent_eval_diffusion_conf().make_sampler()
48
+ )
49
+ else:
50
+ self.latent_sampler = None
51
+ self.eval_latent_sampler = None
52
+
53
+ # initial variables for consistent sampling
54
+ self.register_buffer(
55
+ "x_T", torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size)
56
+ )
57
+
58
+ def render(
59
+ self,
60
+ start,
61
+ motion_direction_start,
62
+ audio_driven,
63
+ face_location,
64
+ face_scale,
65
+ ypr_info,
66
+ noisyT,
67
+ step_T,
68
+ control_flag,
69
+ ):
70
+ if step_T is None:
71
+ sampler = self.eval_sampler
72
+ else:
73
+ sampler = self.conf._make_diffusion_conf(step_T).make_sampler()
74
+
75
+ pred_img = render_condition(
76
+ self.conf,
77
+ self.ema_model,
78
+ sampler,
79
+ start,
80
+ motion_direction_start,
81
+ audio_driven,
82
+ face_location,
83
+ face_scale,
84
+ ypr_info,
85
+ noisyT,
86
+ control_flag,
87
+ )
88
+ return pred_img
89
+
90
+ def forward(self, noise=None, x_start=None, ema_model: bool = False):
91
+ with amp.autocast(False):
92
+ if ema_model:
93
+ model = self.ema_model
94
+ else:
95
+ model = self.model
96
+ gen = self.eval_sampler.sample(model=model, noise=noise, x_start=x_start)
97
+ return gen
98
+
99
+ def setup(self, stage=None) -> None:
100
+ """
101
+ make datasets & seeding each worker separately
102
+ """
103
+ ##############################################
104
+ # NEED TO SET THE SEED SEPARATELY HERE
105
+ if self.conf.seed is not None:
106
+ seed = self.conf.seed * get_world_size() + self.global_rank
107
+ np.random.seed(seed)
108
+ torch.manual_seed(seed)
109
+ torch.cuda.manual_seed(seed)
110
+ logger.info("local seed:", seed)
111
+ ##############################################
112
+
113
+ self.train_data = self.conf.make_dataset()
114
+ logger.info("train data:", len(self.train_data))
115
+ self.val_data = self.train_data
116
+ logger.info("val data:", len(self.val_data))
117
+
118
+ def _train_dataloader(self, drop_last=True):
119
+ """
120
+ really make the dataloader
121
+ """
122
+ # make sure to use the fraction of batch size
123
+ # the batch size is global!
124
+ conf = self.conf.clone()
125
+ conf.batch_size = self.batch_size
126
+
127
+ dataloader = conf.make_loader(
128
+ self.train_data, shuffle=True, drop_last=drop_last
129
+ )
130
+ return dataloader
131
+
132
+ def train_dataloader(self):
133
+ """
134
+ return the dataloader, if diffusion mode => return image dataset
135
+ if latent mode => return the inferred latent dataset
136
+ """
137
+ logger.info("on train dataloader start ...")
138
+ if self.conf.train_mode.require_dataset_infer():
139
+ if self.conds is None:
140
+ # usually we load self.conds from a file
141
+ # so we do not need to do this again!
142
+ self.conds = self.infer_whole_dataset()
143
+ # need to use float32! unless the mean & std will be off!
144
+ # (1, c)
145
+ self.conds_mean.data = self.conds.float().mean(dim=0, keepdim=True)
146
+ self.conds_std.data = self.conds.float().std(dim=0, keepdim=True)
147
+ logger.info("mean:", self.conds_mean.mean(), "std:", self.conds_std.mean())
148
+
149
+ # return the dataset with pre-calculated conds
150
+ conf = self.conf.clone()
151
+ conf.batch_size = self.batch_size
152
+ data = TensorDataset(self.conds)
153
+ return conf.make_loader(data, shuffle=True)
154
+ else:
155
+ return self._train_dataloader()
156
+
157
+ @property
158
+ def batch_size(self):
159
+ """
160
+ local batch size for each worker
161
+ """
162
+ ws = get_world_size()
163
+ assert self.conf.batch_size % ws == 0
164
+ return self.conf.batch_size // ws
165
+
166
+ @property
167
+ def num_samples(self):
168
+ """
169
+ (global) batch size * iterations
170
+ """
171
+ # batch size here is global!
172
+ # global_step already takes into account the accum batches
173
+ return self.global_step * self.conf.batch_size_effective
174
+
175
+ def is_last_accum(self, batch_idx):
176
+ """
177
+ is it the last gradient accumulation loop?
178
+ Used when accum_batches > 1 to check whether the optimizer will perform a "step" in this iteration.
179
+ """
180
+ return (batch_idx + 1) % self.conf.accum_batches == 0
181
+
182
+ def training_step(self, batch, batch_idx):
183
+ """
184
+ given an input, calculate the loss function
185
+ no optimization at this stage.
186
+ """
187
+ with amp.autocast(False):
188
+ motion_start = batch["motion_start"] # torch.Size([B, 512])
189
+ motion_direction = batch["motion_direction"] # torch.Size([B, 125, 20])
190
+ audio_feats = batch["audio_feats"].float() # torch.Size([B, 25, 250, 1024])
191
+ face_location = batch["face_location"].float() # torch.Size([B, 125])
192
+ face_scale = batch["face_scale"].float() # torch.Size([B, 125, 1])
193
+ yaw_pitch_roll = batch["yaw_pitch_roll"].float() # torch.Size([B, 125, 3])
194
+ motion_direction_start = batch[
195
+ "motion_direction_start"
196
+ ].float() # torch.Size([B, 20])
197
+
198
+ # import pdb; pdb.set_trace()
199
+ if self.conf.train_mode == TrainMode.diffusion:
200
+ """
201
+ main training mode!!!
202
+ """
203
+ # with numpy seed we have the problem that the sample t's are related!
204
+ t, weight = self.T_sampler.sample(
205
+ len(motion_start), motion_start.device
206
+ )
207
+ losses = self.sampler.training_losses(
208
+ model=self.model,
209
+ motion_direction_start=motion_direction_start,
210
+ motion_target=motion_direction,
211
+ motion_start=motion_start,
212
+ audio_feats=audio_feats,
213
+ face_location=face_location,
214
+ face_scale=face_scale,
215
+ yaw_pitch_roll=yaw_pitch_roll,
216
+ t=t,
217
+ )
218
+ else:
219
+ raise NotImplementedError()
220
+
221
+ loss = losses["loss"].mean()
222
+ # divide by accum batches to make the accumulated gradient exact!
223
+ for key in losses.keys():
224
+ losses[key] = self.all_gather(losses[key]).mean()
225
+
226
+ if self.global_rank == 0:
227
+ self.logger.experiment.add_scalar(
228
+ "loss", losses["loss"], self.num_samples
229
+ )
230
+ for key in losses:
231
+ self.logger.experiment.add_scalar(
232
+ f"loss/{key}", losses[key], self.num_samples
233
+ )
234
+
235
+ return {"loss": loss}
236
+
237
+ def on_train_batch_end(
238
+ self, outputs, batch, batch_idx: int, dataloader_idx: int
239
+ ) -> None:
240
+ """
241
+ after each training step ...
242
+ """
243
+ if self.is_last_accum(batch_idx):
244
+ if self.conf.train_mode == TrainMode.latent_diffusion:
245
+ # it trains only the latent hence change only the latent
246
+ ema(
247
+ self.model.latent_net,
248
+ self.ema_model.latent_net,
249
+ self.conf.ema_decay,
250
+ )
251
+ else:
252
+ ema(self.model, self.ema_model, self.conf.ema_decay)
253
+
254
+ def on_before_optimizer_step(
255
+ self, optimizer: Optimizer, optimizer_idx: int
256
+ ) -> None:
257
+ # fix the fp16 + clip grad norm problem with pytorch lightning
258
+ # this is the currently correct way to do it
259
+ if self.conf.grad_clip > 0:
260
+ # from trainer.params_grads import grads_norm, iter_opt_params
261
+ params = [p for group in optimizer.param_groups for p in group["params"]]
262
+ torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)
263
+
264
+ def configure_optimizers(self):
265
+ out = {}
266
+ if self.conf.optimizer == OptimizerType.adam:
267
+ optim = torch.optim.Adam(
268
+ self.model.parameters(),
269
+ lr=self.conf.lr,
270
+ weight_decay=self.conf.weight_decay,
271
+ )
272
+ elif self.conf.optimizer == OptimizerType.adamw:
273
+ optim = torch.optim.AdamW(
274
+ self.model.parameters(),
275
+ lr=self.conf.lr,
276
+ weight_decay=self.conf.weight_decay,
277
+ )
278
+ else:
279
+ raise NotImplementedError()
280
+ out["optimizer"] = optim
281
+ if self.conf.warmup > 0:
282
+ sched = torch.optim.lr_scheduler.LambdaLR(
283
+ optim, lr_lambda=WarmupLR(self.conf.warmup)
284
+ )
285
+ out["lr_scheduler"] = {
286
+ "scheduler": sched,
287
+ "interval": "step",
288
+ }
289
+ return out
290
+
291
+ def split_tensor(self, x):
292
+ """
293
+ extract the tensor for a corresponding "worker" in the batch dimension
294
+
295
+ Args:
296
+ x: (n, c)
297
+
298
+ Returns: x: (n_local, c)
299
+ """
300
+ n = len(x)
301
+ rank = self.global_rank
302
+ world_size = get_world_size()
303
+ # print(f'rank: {rank}/{world_size}')
304
+ per_rank = n // world_size
305
+ return x[rank * per_rank : (rank + 1) * per_rank]
306
+
307
+
308
+ def ema(source, target, decay):
309
+ source_dict = source.state_dict()
310
+ target_dict = target.state_dict()
311
+ for key in source_dict.keys():
312
+ target_dict[key].data.copy_(
313
+ target_dict[key].data * decay + source_dict[key].data * (1 - decay)
314
+ )
315
+
316
+
317
+ class WarmupLR:
318
+ def __init__(self, warmup) -> None:
319
+ self.warmup = warmup
320
+
321
+ def __call__(self, step):
322
+ return min(step, self.warmup) / self.warmup
323
+
324
+
325
+ def is_time(num_samples, every, step_size):
326
+ closest = (num_samples // every) * every
327
+ return num_samples - closest < step_size
328
+
329
+
330
+ def train(conf: TrainConfig, gpus, nodes=1, mode: str = "train"):
331
+ logger.info("conf:", conf.name)
332
+ model = LitModel(conf)
333
+
334
+ if not os.path.exists(conf.logdir):
335
+ os.makedirs(conf.logdir)
336
+ checkpoint = ModelCheckpoint(
337
+ dirpath=f"{conf.logdir}", save_last=True, save_top_k=-1, every_n_epochs=10
338
+ )
339
+ checkpoint_path = f"{conf.logdir}/last.ckpt"
340
+ logger.info("ckpt path:", checkpoint_path)
341
+ if os.path.exists(checkpoint_path):
342
+ resume = checkpoint_path
343
+ logger.info("resume!")
344
+ else:
345
+ if conf.continue_from is not None:
346
+ # continue from a checkpoint
347
+ resume = conf.continue_from.path
348
+ else:
349
+ resume = None
350
+
351
+ tb_logger = pl_loggers.TensorBoardLogger(
352
+ save_dir=conf.logdir, name=None, version=""
353
+ )
354
+
355
+ # from pytorch_lightning.
356
+
357
+ plugins = []
358
+ if len(gpus) == 1 and nodes == 1:
359
+ accelerator = None
360
+ else:
361
+ accelerator = "ddp"
362
+ from pytorch_lightning.plugins import DDPPlugin
363
+
364
+ # important for working with gradient checkpoint
365
+ plugins.append(DDPPlugin(find_unused_parameters=True))
366
+
367
+ trainer = pl.Trainer(
368
+ max_steps=conf.total_samples // conf.batch_size_effective,
369
+ resume_from_checkpoint=resume,
370
+ gpus=gpus,
371
+ num_nodes=nodes,
372
+ accelerator=accelerator,
373
+ precision=16 if conf.fp16 else 32,
374
+ callbacks=[
375
+ checkpoint,
376
+ LearningRateMonitor(),
377
+ ],
378
+ # clip in the model instead
379
+ # gradient_clip_val=conf.grad_clip,
380
+ replace_sampler_ddp=True,
381
+ logger=tb_logger,
382
+ accumulate_grad_batches=conf.accum_batches,
383
+ plugins=plugins,
384
+ )
385
+
386
+ trainer.fit(model)
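
Two small pieces of the training loop above can be checked in isolation: the linear warmup multiplier that configure_optimizers feeds into LambdaLR, and the EMA update applied after each accumulation cycle. A sketch, assuming the package is importable; the module sizes are illustrative only:

import torch

from visualizr.experiment import WarmupLR, ema

# LambdaLR multiplier: ramps linearly from 0 to 1 over `warmup` steps, then holds at 1.
warmup = WarmupLR(warmup=5000)
print(warmup(100), warmup(5000), warmup(20000))   # 0.02 1.0 1.0

# EMA update: target <- decay * target + (1 - decay) * source, entry by entry.
online = torch.nn.Linear(4, 4)
shadow = torch.nn.Linear(4, 4)
ema(online, shadow, decay=0.999)
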
src/visualizr/face_sr/face_enhancer.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+ from gfpgan import GFPGANer
6
+ from tqdm import tqdm
7
+
8
+ from visualizr import logger
9
+ from visualizr.face_sr.videoio import load_video_to_cv2
10
+
11
+
12
+ class GeneratorWithLen(object):
13
+ """From https://stackoverflow.com/a/7460929"""
14
+
15
+ def __init__(self, gen, length):
16
+ self.gen = gen
17
+ self.length = length
18
+
19
+ def __len__(self):
20
+ return self.length
21
+
22
+ def __iter__(self):
23
+ return self.gen
24
+
25
+
26
+ def enhancer_list(images, method="gfpgan", bg_upsampler="realesrgan"):
27
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
28
+ return list(gen)
29
+
30
+
31
+ def enhancer_generator_with_len(images, method="gfpgan", bg_upsampler="realesrgan"):
32
+ """Provide a generator with a __len__ method so that it can passed to functions that
33
+ call len()"""
34
+
35
+ if os.path.isfile(images): # handle video to images
36
+ images = load_video_to_cv2(images)
37
+
38
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
39
+ gen_with_len = GeneratorWithLen(gen, len(images))
40
+ return gen_with_len
41
+
42
+
43
+ def enhancer_generator_no_len(images, method="gfpgan", bg_upsampler="realesrgan"):
44
+ """Provide a generator function so that all of the enhanced images don't need
45
+ to be stored in memory at the same time. This can save tons of RAM compared to
46
+ the enhancer function."""
47
+ if method not in ["gfpgan", "RestoreFormer", "codeformer"]:
48
+ raise ValueError(f"Wrong model version {method}.")
49
+ logger.info("face enhancer....")
50
+ if not isinstance(images, list) and os.path.isfile(
51
+ images
52
+ ): # handle video to images
53
+ images = load_video_to_cv2(images)
54
+
55
+ # ------------------------ set up GFPGAN restorer ------------------------
56
+ match method:
57
+ case "gfpgan":
58
+ arch = "clean"
59
+ channel_multiplier = 2
60
+ model_name = "GFPGANv1.4"
61
+ url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
62
+ case "RestoreFormer":
63
+ arch = "RestoreFormer"
64
+ channel_multiplier = 2
65
+ model_name = "RestoreFormer"
66
+ url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
67
+ case "codeformer":
68
+ arch = "CodeFormer"
69
+ channel_multiplier = 2
70
+ model_name = "CodeFormer"
71
+ url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
72
+ # ------------------------ set up background upsampler ------------------------
73
+ if bg_upsampler == "realesrgan":
74
+ if not torch.cuda.is_available(): # CPU
75
+ import warnings
76
+
77
+ warnings.warn(
78
+ "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
79
+ "If you really want to use it, please modify the corresponding codes."
80
+ )
81
+ bg_upsampler = None
82
+ else:
83
+ from basicsr.archs.rrdbnet_arch import RRDBNet
84
+ from realesrgan import RealESRGANer
85
+
86
+ model = RRDBNet(
87
+ num_in_ch=3,
88
+ num_out_ch=3,
89
+ num_feat=64,
90
+ num_block=23,
91
+ num_grow_ch=32,
92
+ scale=2,
93
+ )
94
+ bg_upsampler = RealESRGANer(
95
+ scale=2,
96
+ model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
97
+ model=model,
98
+ tile=400,
99
+ tile_pad=10,
100
+ pre_pad=0,
101
+ half=True,
102
+ ) # need to set False in CPU mode
103
+ else:
104
+ bg_upsampler = None
105
+
106
+ # determine model paths
107
+ model_path = os.path.join("gfpgan/weights", model_name + ".pth")
108
+
109
+ if not os.path.isfile(model_path):
110
+ model_path = os.path.join("checkpoints", model_name + ".pth")
111
+
112
+ if not os.path.isfile(model_path):
113
+ # download pre-trained models from url
114
+ model_path = url
115
+
116
+ restorer = GFPGANer(
117
+ model_path=model_path,
118
+ upscale=2,
119
+ arch=arch,
120
+ channel_multiplier=channel_multiplier,
121
+ bg_upsampler=bg_upsampler,
122
+ )
123
+
124
+ # ------------------------ restore ------------------------
125
+ for idx in tqdm(range(len(images)), "Face Enhancer:"):
126
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
127
+
128
+ # restore faces and background if necessary
129
+ cropped_faces, restored_faces, r_img = restorer.enhance(
130
+ img, has_aligned=False, only_center_face=False, paste_back=True
131
+ )
132
+
133
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
134
+ yield r_img
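A minimal usage sketch of the streaming interface above, assuming the `visualizr` package is importable and the GFPGAN / Real-ESRGAN weights can be found or downloaded; the clip name and output directory are placeholders.

    import os

    import cv2

    from visualizr.face_sr.face_enhancer import enhancer_generator_with_len

    os.makedirs("enhanced_frames", exist_ok=True)
    gen = enhancer_generator_with_len("input_256.mp4", method="gfpgan")
    for idx, rgb in enumerate(gen):  # len(gen) also works, via GeneratorWithLen
        # frames are yielded in RGB; cv2.imwrite expects BGR
        cv2.imwrite(f"enhanced_frames/{idx:05d}.png", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

Consuming the generator like this keeps only one enhanced frame in memory at a time, which is the point of the no-len variant.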
src/visualizr/face_sr/videoio.py ADDED
@@ -0,0 +1,14 @@
+ import cv2
+ import numpy as np
+
+
+ def load_video_to_cv2(input_path: str) -> list[np.ndarray]:
+     video_stream = cv2.VideoCapture(input_path)
+     full_frames = []
+     while True:
+         still_reading, frame = video_stream.read()
+         if not still_reading:
+             video_stream.release()
+             break
+         full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     return full_frames
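A small sketch of the helper above with a placeholder clip name: it returns the whole video as a list of RGB frames, which is what the face enhancer's video branch relies on.

    from visualizr.face_sr.videoio import load_video_to_cv2

    frames = load_video_to_cv2("demo.mp4")  # placeholder path
    # each element is an RGB array of shape (height, width, 3)
    print(len(frames), frames[0].shape)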
src/visualizr/gui.py ADDED
@@ -0,0 +1,97 @@
1
+ from gradio import (
2
+ Accordion,
3
+ Audio,
4
+ Blocks,
5
+ Button,
6
+ Checkbox,
7
+ Column,
8
+ Dropdown,
9
+ Image,
10
+ Markdown,
11
+ Number,
12
+ Row,
13
+ Slider,
14
+ Video,
15
+ )
16
+
17
+ from visualizr.settings import DefaultValues
18
+ from visualizr.utils import generate_video
19
+
20
+
21
+ def app_block() -> Blocks:
22
+ """Create the Gradio interface for the voice generation web application."""
23
+ with Blocks() as app:
24
+ Markdown(value="# AniTalker")
25
+ with Row():
26
+ with Column():
27
+ uploaded_img: Image = Image(type="filepath", label="Reference Image")
28
+ uploaded_audio = Audio(
29
+ type="filepath", label="Input Audio", show_download_button=True
30
+ )
31
+ with Column():
32
+ output_video_256 = Video(label="Generated Video (256)")
33
+ output_video_512 = Video(label="Generated Video (512)")
34
+ output_message = Markdown()
35
+
36
+ generate_button = Button(value="Generate Video")
37
+
38
+ with Accordion(label="Configuration"):
39
+ infer_type = Dropdown(
40
+ label="Inference Type",
41
+ choices=[
42
+ "mfcc_full_control",
43
+ "mfcc_pose_only",
44
+ "hubert_pose_only",
45
+ "hubert_audio_only",
46
+ "hubert_full_control",
47
+ ],
48
+ value="hubert_audio_only",
49
+ )
50
+ face_sr = Checkbox(label="Enable Face Super-Resolution (512*512)")
51
+ seed = Number(label="Seed", value=DefaultValues().seed)
52
+ pose_yaw = Slider(
53
+ label="pose_yaw",
54
+ minimum=-1,
55
+ maximum=1,
56
+ value=DefaultValues().pose_yaw,
57
+ )
58
+ pose_pitch = Slider(
59
+ label="pose_pitch",
60
+ minimum=-1,
61
+ maximum=1,
62
+ value=DefaultValues().pose_pitch,
63
+ )
64
+ pose_roll = Slider(
65
+ label="pose_roll",
66
+ minimum=-1,
67
+ maximum=1,
68
+ value=DefaultValues().pose_roll,
69
+ )
70
+ face_location = Slider(
71
+ label="face_location", maximum=1, value=DefaultValues().face_location
72
+ )
73
+ face_scale = Slider(
74
+ label="face_scale", maximum=1, value=DefaultValues().face_scale
75
+ )
76
+ step_t = Slider(
77
+ label="step_T", minimum=1, step=1, value=DefaultValues().step_T
78
+ )
79
+
80
+ generate_button.click(
81
+ fn=generate_video,
82
+ inputs=[
83
+ uploaded_img,
84
+ uploaded_audio,
85
+ infer_type,
86
+ pose_yaw,
87
+ pose_pitch,
88
+ pose_roll,
89
+ face_location,
90
+ face_scale,
91
+ step_t,
92
+ face_sr,
93
+ seed,
94
+ ],
95
+ outputs=[output_video_256, output_video_512, output_message],
96
+ )
97
+ return app
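A hedged sketch of serving the Blocks app built above; the host and port are illustrative choices, not values taken from this repository.

    from visualizr.gui import app_block

    if __name__ == "__main__":
        app_block().launch(server_name="0.0.0.0", server_port=7860)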
src/visualizr/model/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from typing import Union
2
+
3
+ from visualizr.model.unet import BeatGANsUNetConfig, BeatGANsUNetModel
4
+ from visualizr.model.unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
5
+
6
+ Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
7
+ ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
src/visualizr/model/base.py ADDED
@@ -0,0 +1,28 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ class BaseModule(torch.nn.Module):
6
+ def __init__(self):
7
+ super(BaseModule, self).__init__()
8
+
9
+ @property
10
+ def nparams(self):
11
+ """
12
+ Returns number of trainable parameters of the module.
13
+ """
14
+ num_params = 0
15
+ for name, param in self.named_parameters():
16
+ if param.requires_grad:
17
+ num_params += np.prod(param.detach().cpu().numpy().shape)
18
+ return num_params
19
+
20
+ def relocate_input(self, x: list):
21
+ """
22
+ Relocates provided tensors to the same device set for the module.
23
+ """
24
+ device = next(self.parameters()).device
25
+ for i in range(len(x)):
26
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
27
+ x[i] = x[i].to(device)
28
+ return x
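A toy sketch of the two helpers above (`nparams` and `relocate_input`); the `Toy` module and its layer sizes are invented for illustration.

    import torch

    from visualizr.model.base import BaseModule


    class Toy(BaseModule):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)


    toy = Toy()
    print(toy.nparams)  # 4*2 weights + 2 biases = 10 trainable parameters
    (x,) = toy.relocate_input([torch.zeros(1, 4)])  # moved to the module's device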
src/visualizr/model/blocks.py ADDED
@@ -0,0 +1,572 @@
1
+ import math
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from numbers import Number
6
+
7
+ import numpy as np
8
+ import torch as th
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from visualizr.config_base import BaseConfig
13
+ from visualizr.model.nn import (
14
+ avg_pool_nd,
15
+ conv_nd,
16
+ linear,
17
+ normalization,
18
+ torch_checkpoint,
19
+ zero_module,
20
+ )
21
+
22
+
23
+ class ScaleAt(Enum):
24
+ after_norm = "afternorm"
25
+
26
+
27
+ class TimestepBlock(nn.Module):
28
+ """
29
+ Any module where forward() takes timestep embeddings as a second argument.
30
+ """
31
+
32
+ @abstractmethod
33
+ def forward(self, x, emb=None, cond=None, lateral=None):
34
+ """
35
+ Apply the module to `x` given `emb` timestep embeddings.
36
+ """
37
+
38
+
39
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
40
+ """
41
+ A sequential module that passes timestep embeddings to the children that
42
+ support it as an extra input.
43
+ """
44
+
45
+ def forward(self, x, emb=None, cond=None, lateral=None):
46
+ for layer in self:
47
+ if isinstance(layer, TimestepBlock):
48
+ x = layer(x, emb=emb, cond=cond, lateral=lateral)
49
+ else:
50
+ x = layer(x)
51
+ return x
52
+
53
+
54
+ @dataclass
55
+ class ResBlockConfig(BaseConfig):
56
+ channels: int
57
+ emb_channels: int
58
+ dropout: float
59
+ out_channels: int = None
60
+ # condition the resblock with time (and encoder's output)
61
+ use_condition: bool = True
62
+ # whether to use 3x3 conv for skip path when the channels aren't matched
63
+ use_conv: bool = False
64
+ # dimension of conv (always 2 = 2d)
65
+ dims: int = 2
66
+ # gradient checkpoint
67
+ use_checkpoint: bool = False
68
+ up: bool = False
69
+ down: bool = False
70
+ # whether to condition with both time & encoder's output
71
+ two_cond: bool = False
72
+ # number of encoders' output channels
73
+ cond_emb_channels: int = None
74
+ # suggest: False
75
+ has_lateral: bool = False
76
+ lateral_channels: int = None
77
+ # whether to init the convolution with zero weights
78
+ # this is default from BeatGANs and seems to help learning
79
+ use_zero_module: bool = True
80
+
81
+ def __post_init__(self):
82
+ self.out_channels = self.out_channels or self.channels
83
+ self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
84
+
85
+ def make_model(self):
86
+ return ResBlock(self)
87
+
88
+
89
+ class ResBlock(TimestepBlock):
90
+ """
91
+ A residual block that can optionally change the number of channels.
92
+
93
+ total layers:
94
+ in_layers
95
+ - norm
96
+ - act
97
+ - conv
98
+ out_layers
99
+ - norm
100
+ - (modulation)
101
+ - act
102
+ - conv
103
+ """
104
+
105
+ def __init__(self, conf: ResBlockConfig):
106
+ super().__init__()
107
+ self.conf = conf
108
+
109
+ #############################
110
+ # IN LAYERS
111
+ #############################
112
+ assert conf.lateral_channels is None
113
+ layers = [
114
+ normalization(conf.channels),
115
+ nn.SiLU(),
116
+ conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1),
117
+ ]
118
+ self.in_layers = nn.Sequential(*layers)
119
+
120
+ self.updown = conf.up or conf.down
121
+
122
+ if conf.up:
123
+ self.h_upd = Upsample(conf.channels, False, conf.dims)
124
+ self.x_upd = Upsample(conf.channels, False, conf.dims)
125
+ elif conf.down:
126
+ self.h_upd = Downsample(conf.channels, False, conf.dims)
127
+ self.x_upd = Downsample(conf.channels, False, conf.dims)
128
+ else:
129
+ self.h_upd = self.x_upd = nn.Identity()
130
+
131
+ #############################
132
+ # OUT LAYERS CONDITIONS
133
+ #############################
134
+ if conf.use_condition:
135
+ # condition layers for the out_layers
136
+ self.emb_layers = nn.Sequential(
137
+ nn.SiLU(),
138
+ linear(conf.emb_channels, 2 * conf.out_channels),
139
+ )
140
+
141
+ if conf.two_cond:
142
+ self.cond_emb_layers = nn.Sequential(
143
+ nn.SiLU(),
144
+ linear(conf.cond_emb_channels, conf.out_channels),
145
+ )
146
+ #############################
147
+ # OUT LAYERS (ignored when there is no condition)
148
+ #############################
149
+ # original version
150
+ conv = conv_nd(
151
+ conf.dims, conf.out_channels, conf.out_channels, 3, padding=1
152
+ )
153
+ if conf.use_zero_module:
154
+ # zero out the weights; it seems to help training
156
+ conv = zero_module(conv)
157
+
158
+ # construct the layers
159
+ # - norm
160
+ # - (modulation)
161
+ # - act
162
+ # - dropout
163
+ # - conv
164
+ layers = []
165
+ layers += [
166
+ normalization(conf.out_channels),
167
+ nn.SiLU(),
168
+ nn.Dropout(p=conf.dropout),
169
+ conv,
170
+ ]
171
+ self.out_layers = nn.Sequential(*layers)
172
+
173
+ #############################
174
+ # SKIP LAYERS
175
+ #############################
176
+ if conf.out_channels == conf.channels:
177
+ # cannot be used with gatedconv; also, gatedconv is always used as the first block
178
+ self.skip_connection = nn.Identity()
179
+ else:
180
+ if conf.use_conv:
181
+ kernel_size = 3
182
+ padding = 1
183
+ else:
184
+ kernel_size = 1
185
+ padding = 0
186
+
187
+ self.skip_connection = conv_nd(
188
+ conf.dims,
189
+ conf.channels,
190
+ conf.out_channels,
191
+ kernel_size,
192
+ padding=padding,
193
+ )
194
+
195
+ def forward(self, x, emb=None, cond=None, lateral=None):
196
+ """
197
+ Apply the block to a Tensor, conditioned on a timestep embedding.
198
+
199
+ Args:
200
+ x: input
201
+ lateral: lateral connection from the encoder
202
+ """
203
+ return torch_checkpoint(
204
+ self._forward, (x, emb, cond, lateral), self.conf.use_checkpoint
205
+ )
206
+
207
+ def _forward(
208
+ self,
209
+ x,
210
+ emb=None,
211
+ cond=None,
212
+ lateral=None,
213
+ ):
214
+ """
215
+ Args:
216
+ lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
217
+ """
218
+ if self.conf.has_lateral:
219
+ # lateral may be supplied even if the block doesn't require it;
+ # the model uses it only when "has_lateral" is set
221
+ assert lateral is not None
222
+ x = th.cat([x, lateral], dim=1)
223
+
224
+ if self.updown:
225
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
226
+ h = in_rest(x)
227
+ h = self.h_upd(h)
228
+ x = self.x_upd(x)
229
+ h = in_conv(h)
230
+ else:
231
+ h = self.in_layers(x)
232
+
233
+ if self.conf.use_condition:
234
+ # it's possible that the network may not receive the time emb;
+ # this happens with autoenc and setting the time_at
236
+ if emb is not None:
237
+ emb_out = self.emb_layers(emb).type(h.dtype)
238
+ else:
239
+ emb_out = None
240
+
241
+ if self.conf.two_cond:
242
+ # it's possible that the network is two_cond
243
+ # but it doesn't get the second condition
244
+ # in which case, we ignore the second condition
245
+ # and treat as if the network has one condition
246
+ if cond is None:
247
+ cond_out = None
248
+ else:
249
+ cond_out = self.cond_emb_layers(cond).type(h.dtype)
250
+
251
+ if cond_out is not None:
252
+ while len(cond_out.shape) < len(h.shape):
253
+ cond_out = cond_out[..., None]
254
+ else:
255
+ cond_out = None
256
+
257
+ # this is the new refactored code
258
+ h = apply_conditions(
259
+ h=h,
260
+ emb=emb_out,
261
+ cond=cond_out,
262
+ layers=self.out_layers,
263
+ scale_bias=1,
264
+ in_channels=self.conf.out_channels,
265
+ up_down_layer=None,
266
+ )
267
+
268
+ return self.skip_connection(x) + h
269
+
270
+
271
+ def apply_conditions(
272
+ h,
273
+ emb=None,
274
+ cond=None,
275
+ layers: nn.Sequential = None,
276
+ scale_bias: float = 1,
277
+ in_channels: int = 512,
278
+ up_down_layer: nn.Module = None,
279
+ ):
280
+ """
281
+ apply conditions on the feature maps
282
+
283
+ Args:
284
+ emb: time conditional (ready to scale + shift)
285
+ cond: encoder's conditional (ready to scale + shift)
286
+ """
287
+ two_cond = emb is not None and cond is not None
288
+
289
+ if emb is not None:
290
+ # adjusting shapes
291
+ while len(emb.shape) < len(h.shape):
292
+ emb = emb[..., None]
293
+
294
+ if two_cond:
295
+ # adjusting shapes
296
+ while len(cond.shape) < len(h.shape):
297
+ cond = cond[..., None]
298
+ # time first
299
+ scale_shifts = [emb, cond]
300
+ else:
301
+ # "cond" is not used with single cond mode
302
+ scale_shifts = [emb]
303
+
304
+ # support scale, shift or shift only
305
+ for i, each in enumerate(scale_shifts):
306
+ if each is None:
307
+ # special case: the condition is not provided
308
+ a = None
309
+ b = None
310
+ else:
311
+ if each.shape[1] == in_channels * 2:
312
+ a, b = th.chunk(each, 2, dim=1)
313
+ else:
314
+ a = each
315
+ b = None
316
+ scale_shifts[i] = (a, b)
317
+
318
+ # condition scale bias could be a list
319
+ if isinstance(scale_bias, Number):
320
+ biases = [scale_bias] * len(scale_shifts)
321
+ else:
322
+ # a list
323
+ biases = scale_bias
324
+
325
+ # default, the scale & shift are applied after the group norm but BEFORE SiLU
326
+ pre_layers, post_layers = layers[0], layers[1:]
327
+
328
+ # split the post layers so we can scale up or down before the conv
329
+ # post layers will contain only the conv
330
+ mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
331
+
332
+ h = pre_layers(h)
333
+ # scale and shift for each condition
334
+ for i, (scale, shift) in enumerate(scale_shifts):
335
+ # if scale is None, it indicates that the condition is not provided
336
+ if scale is not None:
337
+ h = h * (biases[i] + scale)
338
+ if shift is not None:
339
+ h = h + shift
340
+ h = mid_layers(h)
341
+
342
+ # upscale or downscale if any just before the last conv
343
+ if up_down_layer is not None:
344
+ h = up_down_layer(h)
345
+ h = post_layers(h)
346
+ return h
347
+
348
+
349
+ class Upsample(nn.Module):
350
+ """
351
+ An upsampling layer with an optional convolution.
352
+
353
+ :param channels: channels in the inputs and outputs.
354
+ :param use_conv: a bool determining if a convolution is applied.
355
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
356
+ upsampling occurs in the inner-two dimensions.
357
+ """
358
+
359
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
360
+ super().__init__()
361
+ self.channels = channels
362
+ self.out_channels = out_channels or channels
363
+ self.use_conv = use_conv
364
+ self.dims = dims
365
+ if use_conv:
366
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
367
+
368
+ def forward(self, x):
369
+ assert x.shape[1] == self.channels
370
+ if self.dims == 3:
371
+ x = F.interpolate(
372
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
373
+ )
374
+ else:
375
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
376
+ if self.use_conv:
377
+ x = self.conv(x)
378
+ return x
379
+
380
+
381
+ class Downsample(nn.Module):
382
+ """
383
+ A downsampling layer with an optional convolution.
384
+
385
+ :param channels: channels in the inputs and outputs.
386
+ :param use_conv: a bool determining if a convolution is applied.
387
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
388
+ downsampling occurs in the inner-two dimensions.
389
+ """
390
+
391
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
392
+ super().__init__()
393
+ self.channels = channels
394
+ self.out_channels = out_channels or channels
395
+ self.use_conv = use_conv
396
+ self.dims = dims
397
+ stride = 2 if dims != 3 else (1, 2, 2)
398
+ if use_conv:
399
+ self.op = conv_nd(
400
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
401
+ )
402
+ else:
403
+ assert self.channels == self.out_channels
404
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
405
+
406
+ def forward(self, x):
407
+ assert x.shape[1] == self.channels
408
+ return self.op(x)
409
+
410
+
411
+ class AttentionBlock(nn.Module):
412
+ """An attention block that allows spatial positions to attend to each other."""
413
+
414
+ def __init__(
415
+ self,
416
+ channels,
417
+ num_heads=1,
418
+ num_head_channels=-1,
419
+ use_checkpoint=False,
420
+ use_new_attention_order=False,
421
+ ):
422
+ super().__init__()
423
+ self.channels = channels
424
+ if num_head_channels == -1:
425
+ self.num_heads = num_heads
426
+ else:
427
+ assert channels % num_head_channels == 0, (
428
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
429
+ )
430
+ self.num_heads = channels // num_head_channels
431
+ self.use_checkpoint = use_checkpoint
432
+ self.norm = normalization(channels)
433
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
434
+ if use_new_attention_order:
435
+ # split qkv before split heads
436
+ self.attention = QKVAttention(self.num_heads)
437
+ else:
438
+ # split heads before split qkv
439
+ self.attention = QKVAttentionLegacy(self.num_heads)
440
+
441
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
442
+
443
+ def forward(self, x):
444
+ return torch_checkpoint(self._forward, (x,), self.use_checkpoint)
445
+
446
+ def _forward(self, x):
447
+ b, c, *spatial = x.shape
448
+ x = x.reshape(b, c, -1)
449
+ qkv = self.qkv(self.norm(x))
450
+ h = self.attention(qkv)
451
+ h = self.proj_out(h)
452
+ return (x + h).reshape(b, c, *spatial)
453
+
454
+
455
+ def count_flops_attn(model, _x, y):
456
+ """
457
+ A counter for the `thop` package to count the operations in an
458
+ attention operation.
459
+ Meant to be used like:
460
+ macs, params = thop.profile(
461
+ model,
462
+ inputs=(inputs, timestamps),
463
+ custom_ops={QKVAttention: QKVAttention.count_flops},
464
+ )
465
+ """
466
+ b, c, *spatial = y[0].shape
467
+ num_spatial = int(np.prod(spatial))
468
+ # We perform two matmuls with the same number of ops.
469
+ # The first computes the weight matrix, the second computes
470
+ # the combination of the value vectors.
471
+ matmul_ops = 2 * b * (num_spatial**2) * c
472
+ model.total_ops += th.DoubleTensor([matmul_ops])
473
+
474
+
475
+ class QKVAttentionLegacy(nn.Module):
476
+ """
477
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
478
+ """
479
+
480
+ def __init__(self, n_heads):
481
+ super().__init__()
482
+ self.n_heads = n_heads
483
+
484
+ def forward(self, qkv):
485
+ """
486
+ Apply QKV attention.
487
+
488
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
489
+ :return: an [N x (H * C) x T] tensor after attention.
490
+ """
491
+ bs, width, length = qkv.shape
492
+ assert width % (3 * self.n_heads) == 0
493
+ ch = width // (3 * self.n_heads)
494
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
495
+ scale = 1 / math.sqrt(math.sqrt(ch))
496
+ weight = th.einsum(
497
+ "bct,bcs->bts", q * scale, k * scale
498
+ ) # More stable with f16 than dividing afterwards
499
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
500
+ a = th.einsum("bts,bcs->bct", weight, v)
501
+ return a.reshape(bs, -1, length)
502
+
503
+ @staticmethod
504
+ def count_flops(model, _x, y):
505
+ return count_flops_attn(model, _x, y)
506
+
507
+
508
+ class QKVAttention(nn.Module):
509
+ """
510
+ A module which performs QKV attention and splits in a different order.
511
+ """
512
+
513
+ def __init__(self, n_heads):
514
+ super().__init__()
515
+ self.n_heads = n_heads
516
+
517
+ def forward(self, qkv):
518
+ """
519
+ Apply QKV attention.
520
+
521
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
522
+ :return: an [N x (H * C) x T] tensor after attention.
523
+ """
524
+ bs, width, length = qkv.shape
525
+ assert width % (3 * self.n_heads) == 0
526
+ ch = width // (3 * self.n_heads)
527
+ q, k, v = qkv.chunk(3, dim=1)
528
+ scale = 1 / math.sqrt(math.sqrt(ch))
529
+ weight = th.einsum(
530
+ "bct,bcs->bts",
531
+ (q * scale).view(bs * self.n_heads, ch, length),
532
+ (k * scale).view(bs * self.n_heads, ch, length),
533
+ ) # More stable with f16 than dividing afterwards
534
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
535
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
536
+ return a.reshape(bs, -1, length)
537
+
538
+ @staticmethod
539
+ def count_flops(model, _x, y):
540
+ return count_flops_attn(model, _x, y)
541
+
542
+
543
+ class AttentionPool2d(nn.Module):
544
+ """
545
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
546
+ """
547
+
548
+ def __init__(
549
+ self,
550
+ spacial_dim: int,
551
+ embed_dim: int,
552
+ num_heads_channels: int,
553
+ output_dim: int = None,
554
+ ):
555
+ super().__init__()
556
+ self.positional_embedding = nn.Parameter(
557
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
558
+ )
559
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
560
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
561
+ self.num_heads = embed_dim // num_heads_channels
562
+ self.attention = QKVAttention(self.num_heads)
563
+
564
+ def forward(self, x):
565
+ b, c, *_spatial = x.shape
566
+ x = x.reshape(b, c, -1) # NC(HW)
567
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
568
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
569
+ x = self.qkv_proj(x)
570
+ x = self.attention(x)
571
+ x = self.c_proj(x)
572
+ return x[:, :, 0]
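A shape-level sketch of a single `ResBlock` conditioned on a sinusoidal timestep embedding (`timestep_embedding` comes from `nn.py` below); the channel sizes are arbitrary, not a released configuration.

    import torch

    from visualizr.model.blocks import ResBlockConfig
    from visualizr.model.nn import timestep_embedding

    block = ResBlockConfig(channels=64, emb_channels=512, dropout=0.0).make_model()
    x = torch.randn(2, 64, 32, 32)                       # feature map
    emb = timestep_embedding(torch.tensor([3, 7]), 512)  # (2, 512) time embedding
    print(block(x, emb=emb).shape)                       # torch.Size([2, 64, 32, 32])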
src/visualizr/model/diffusion.py ADDED
@@ -0,0 +1,323 @@
1
+ import math
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ from visualizr.model.base import BaseModule
7
+
8
+
9
+ class Mish(BaseModule):
10
+ def forward(self, x):
11
+ return x * torch.tanh(torch.nn.functional.softplus(x))
12
+
13
+
14
+ class Upsample(BaseModule):
15
+ def __init__(self, dim):
16
+ super(Upsample, self).__init__()
17
+ self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
18
+
19
+ def forward(self, x):
20
+ return self.conv(x)
21
+
22
+
23
+ class Downsample(BaseModule):
24
+ def __init__(self, dim):
25
+ super(Downsample, self).__init__()
26
+ self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
27
+
28
+ def forward(self, x):
29
+ return self.conv(x)
30
+
31
+
32
+ class Rezero(BaseModule):
33
+ def __init__(self, fn):
34
+ super(Rezero, self).__init__()
35
+ self.fn = fn
36
+ self.g = torch.nn.Parameter(torch.zeros(1))
37
+
38
+ def forward(self, x):
39
+ return self.fn(x) * self.g
40
+
41
+
42
+ class Block(BaseModule):
43
+ def __init__(self, dim, dim_out, groups=8):
44
+ super(Block, self).__init__()
45
+ self.block = torch.nn.Sequential(
46
+ torch.nn.Conv2d(dim, dim_out, 3, padding=1),
47
+ torch.nn.GroupNorm(groups, dim_out),
48
+ Mish(),
49
+ )
50
+
51
+ def forward(self, x, mask):
52
+ output = self.block(x * mask)
53
+ return output * mask
54
+
55
+
56
+ class ResnetBlock(BaseModule):
57
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
58
+ super(ResnetBlock, self).__init__()
59
+ self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
60
+
61
+ self.block1 = Block(dim, dim_out, groups=groups)
62
+ self.block2 = Block(dim_out, dim_out, groups=groups)
63
+ if dim != dim_out:
64
+ self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
65
+ else:
66
+ self.res_conv = torch.nn.Identity()
67
+
68
+ def forward(self, x, mask, time_emb):
69
+ h = self.block1(x, mask)
70
+ h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
71
+ h = self.block2(h, mask)
72
+ output = h + self.res_conv(x * mask)
73
+ return output
74
+
75
+
76
+ class LinearAttention(BaseModule):
77
+ def __init__(self, dim, heads=4, dim_head=32):
78
+ super(LinearAttention, self).__init__()
79
+ self.heads = heads
80
+ hidden_dim = dim_head * heads
81
+ self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
82
+ self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
83
+
84
+ def forward(self, x):
85
+ b, c, h, w = x.shape
86
+ qkv = self.to_qkv(x)
87
+ q, k, v = rearrange(
88
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
89
+ )
90
+ k = k.softmax(dim=-1)
91
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
92
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
93
+ out = rearrange(
94
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
95
+ )
96
+ return self.to_out(out)
97
+
98
+
99
+ class Residual(BaseModule):
100
+ def __init__(self, fn):
101
+ super(Residual, self).__init__()
102
+ self.fn = fn
103
+
104
+ def forward(self, x, *args, **kwargs):
105
+ output = self.fn(x, *args, **kwargs) + x
106
+ return output
107
+
108
+
109
+ class SinusoidalPosEmb(BaseModule):
110
+ def __init__(self, dim):
111
+ super(SinusoidalPosEmb, self).__init__()
112
+ self.dim = dim
113
+
114
+ def forward(self, x, scale=1000):
115
+ device = x.device
116
+ half_dim = self.dim // 2
117
+ emb = math.log(10000) / (half_dim - 1)
118
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
119
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
120
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
121
+ return emb
122
+
123
+
124
+ class GradLogPEstimator2d(BaseModule):
125
+ def __init__(
126
+ self,
127
+ dim,
128
+ dim_mults=(1, 2, 4),
129
+ groups=8,
130
+ n_spks=None,
131
+ spk_emb_dim=64,
132
+ n_feats=80,
133
+ pe_scale=1000,
134
+ ):
135
+ super(GradLogPEstimator2d, self).__init__()
136
+ self.dim = dim
137
+ self.dim_mults = dim_mults
138
+ self.groups = groups
139
+ self.n_spks = n_spks if n_spks is not None else 1
140
+ self.spk_emb_dim = spk_emb_dim
141
+ self.pe_scale = pe_scale
142
+
143
+ if n_spks > 1:
144
+ self.spk_mlp = torch.nn.Sequential(
145
+ torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4),
146
+ Mish(),
147
+ torch.nn.Linear(spk_emb_dim * 4, n_feats),
148
+ )
149
+ self.time_pos_emb = SinusoidalPosEmb(dim)
150
+ self.mlp = torch.nn.Sequential(
151
+ torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)
152
+ )
153
+
154
+ dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
155
+ in_out = list(zip(dims[:-1], dims[1:]))
156
+ self.downs = torch.nn.ModuleList([])
157
+ self.ups = torch.nn.ModuleList([])
158
+ num_resolutions = len(in_out)
159
+
160
+ for ind, (dim_in, dim_out) in enumerate(in_out):
161
+ is_last = ind >= (num_resolutions - 1)
162
+ self.downs.append(
163
+ torch.nn.ModuleList(
164
+ [
165
+ ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
166
+ ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
167
+ Residual(Rezero(LinearAttention(dim_out))),
168
+ Downsample(dim_out) if not is_last else torch.nn.Identity(),
169
+ ]
170
+ )
171
+ )
172
+
173
+ mid_dim = dims[-1]
174
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
175
+ self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
176
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
177
+
178
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
179
+ self.ups.append(
180
+ torch.nn.ModuleList(
181
+ [
182
+ ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
183
+ ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
184
+ Residual(Rezero(LinearAttention(dim_in))),
185
+ Upsample(dim_in),
186
+ ]
187
+ )
188
+ )
189
+ self.final_block = Block(dim, dim)
190
+ self.final_conv = torch.nn.Conv2d(dim, 1, 1)
191
+
192
+ def forward(self, x, mask, mu, t, spk=None):
193
+ s = None  # speaker embedding; set only when a speaker vector is provided
+ if spk is not None:
+ s = self.spk_mlp(spk)
196
+
197
+ t = self.time_pos_emb(t, scale=self.pe_scale)
198
+ t = self.mlp(t)
199
+
200
+ if self.n_spks < 2:
201
+ x = torch.stack([mu, x], 1)
202
+ else:
203
+ s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
204
+ x = torch.stack([mu, x, s], 1)
205
+ mask = mask.unsqueeze(1)
206
+
207
+ hiddens = []
208
+ masks = [mask]
209
+ for resnet1, resnet2, attn, downsample in self.downs:
210
+ mask_down = masks[-1]
211
+ x = resnet1(x, mask_down, t)
212
+ x = resnet2(x, mask_down, t)
213
+ x = attn(x)
214
+ hiddens.append(x)
215
+ x = downsample(x * mask_down)
216
+ masks.append(mask_down[:, :, :, ::2])
217
+
218
+ masks = masks[:-1]
219
+ mask_mid = masks[-1]
220
+ x = self.mid_block1(x, mask_mid, t)
221
+ x = self.mid_attn(x)
222
+ x = self.mid_block2(x, mask_mid, t)
223
+
224
+ for resnet1, resnet2, attn, upsample in self.ups:
225
+ mask_up = masks.pop()
226
+ x = torch.cat((x, hiddens.pop()), dim=1)
227
+ x = resnet1(x, mask_up, t)
228
+ x = resnet2(x, mask_up, t)
229
+ x = attn(x)
230
+ x = upsample(x * mask_up)
231
+
232
+ x = self.final_block(x, mask)
233
+ output = self.final_conv(x * mask)
234
+
235
+ return (output * mask).squeeze(1)
236
+
237
+
238
+ def get_noise(t, beta_init, beta_term, cumulative=False):
239
+ if cumulative:
240
+ noise = beta_init * t + 0.5 * (beta_term - beta_init) * (t**2)
241
+ else:
242
+ noise = beta_init + (beta_term - beta_init) * t
243
+ return noise
244
+
245
+
246
+ class Diffusion(BaseModule):
247
+ def __init__(
248
+ self,
249
+ n_feats,
250
+ dim,
251
+ n_spks=1,
252
+ spk_emb_dim=64,
253
+ beta_min=0.05,
254
+ beta_max=20,
255
+ pe_scale=1000,
256
+ ):
257
+ super(Diffusion, self).__init__()
258
+ self.n_feats = n_feats
259
+ self.dim = dim
260
+ self.n_spks = n_spks
261
+ self.spk_emb_dim = spk_emb_dim
262
+ self.beta_min = beta_min
263
+ self.beta_max = beta_max
264
+ self.pe_scale = pe_scale
265
+
266
+ self.estimator = GradLogPEstimator2d(
267
+ dim, n_spks=n_spks, spk_emb_dim=spk_emb_dim, pe_scale=pe_scale
268
+ )
269
+
270
+ def forward_diffusion(self, x0, mask, mu, t):
271
+ time = t.unsqueeze(-1).unsqueeze(-1)
272
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
273
+ mean = x0 * torch.exp(-0.5 * cum_noise) + mu * (
274
+ 1.0 - torch.exp(-0.5 * cum_noise)
275
+ )
276
+ variance = 1.0 - torch.exp(-cum_noise)
277
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
278
+ xt = mean + z * torch.sqrt(variance)
279
+ return xt * mask, z * mask
280
+
281
+ @torch.no_grad()
282
+ def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
283
+ h = 1.0 / n_timesteps
284
+ xt = z * mask
285
+ for i in range(n_timesteps):
286
+ t = (1.0 - (i + 0.5) * h) * torch.ones(
287
+ z.shape[0], dtype=z.dtype, device=z.device
288
+ )
289
+ time = t.unsqueeze(-1).unsqueeze(-1)
290
+ noise_t = get_noise(time, self.beta_min, self.beta_max, cumulative=False)
291
+ if stoc: # adds stochastic term
292
+ dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
293
+ dxt_det = dxt_det * noise_t * h
294
+ dxt_stoc = torch.randn(
295
+ z.shape, dtype=z.dtype, device=z.device, requires_grad=False
296
+ )
297
+ dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
298
+ dxt = dxt_det + dxt_stoc
299
+ else:
300
+ dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
301
+ dxt = dxt * noise_t * h
302
+ xt = (xt - dxt) * mask
303
+ return xt
304
+
305
+ @torch.no_grad()
306
+ def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
307
+ return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
308
+
309
+ def loss_t(self, x0, mask, mu, t, spk=None):
310
+ xt, z = self.forward_diffusion(x0, mask, mu, t)
311
+ time = t.unsqueeze(-1).unsqueeze(-1)
312
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
313
+ noise_estimation = self.estimator(xt, mask, mu, t, spk)
314
+ noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
315
+ loss = torch.sum((noise_estimation + z) ** 2) / (torch.sum(mask) * self.n_feats)
316
+ return loss, xt
317
+
318
+ def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
319
+ t = torch.rand(
320
+ x0.shape[0], dtype=x0.dtype, device=x0.device, requires_grad=False
321
+ )
322
+ t = torch.clamp(t, offset, 1.0 - offset)
323
+ return self.loss_t(x0, mask, mu, t, spk)
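A tiny sketch of the noise schedule used by `Diffusion` above: `get_noise` returns either the instantaneous beta at a normalized time t in [0, 1] or its integral up to t; the 0.05 / 20 values mirror the class defaults `beta_min` / `beta_max`.

    import torch

    from visualizr.model.diffusion import get_noise

    t = torch.tensor([0.0, 0.5, 1.0])
    beta = get_noise(t, 0.05, 20.0)                  # instantaneous beta(t)
    cum = get_noise(t, 0.05, 20.0, cumulative=True)  # integral of beta from 0 to t
    print(beta, cum)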
src/visualizr/model/latentnet.py ADDED
@@ -0,0 +1,189 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import NamedTuple, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import init
8
+
9
+ from visualizr.choices import Activation
10
+ from visualizr.config_base import BaseConfig
11
+ from visualizr.model.nn import timestep_embedding
12
+
13
+
14
+ class LatentNetType(Enum):
15
+ none = "none"
16
+ # injecting inputs into the hidden layers
17
+ skip = "skip"
18
+
19
+
20
+ class LatentNetReturn(NamedTuple):
21
+ pred: torch.Tensor = None
22
+
23
+
24
+ @dataclass
25
+ class MLPSkipNetConfig(BaseConfig):
26
+ """
27
+ default MLP for the latent DPM in the paper!
28
+ """
29
+
30
+ num_channels: int
31
+ skip_layers: Tuple[int]
32
+ num_hid_channels: int
33
+ num_layers: int
34
+ num_time_emb_channels: int = 64
35
+ activation: Activation = Activation.silu
36
+ use_norm: bool = True
37
+ condition_bias: float = 1
38
+ dropout: float = 0
39
+ last_act: Activation = Activation.none
40
+ num_time_layers: int = 2
41
+ time_last_act: bool = False
42
+
43
+ def make_model(self):
44
+ return MLPSkipNet(self)
45
+
46
+
47
+ class MLPSkipNet(nn.Module):
48
+ """
49
+ concat x to hidden layers
50
+
51
+ default MLP for the latent DPM in the paper!
52
+ """
53
+
54
+ def __init__(self, conf: MLPSkipNetConfig):
55
+ super().__init__()
56
+ self.conf = conf
57
+
58
+ layers = []
59
+ for i in range(conf.num_time_layers):
60
+ if i == 0:
61
+ a = conf.num_time_emb_channels
62
+ b = conf.num_channels
63
+ else:
64
+ a = conf.num_channels
65
+ b = conf.num_channels
66
+ layers.append(nn.Linear(a, b))
67
+ if i < conf.num_time_layers - 1 or conf.time_last_act:
68
+ layers.append(conf.activation.get_act())
69
+ self.time_embed = nn.Sequential(*layers)
70
+
71
+ self.layers = nn.ModuleList([])
72
+ for i in range(conf.num_layers):
73
+ if i == 0:
74
+ act = conf.activation
75
+ norm = conf.use_norm
76
+ cond = True
77
+ a, b = conf.num_channels, conf.num_hid_channels
78
+ dropout = conf.dropout
79
+ elif i == conf.num_layers - 1:
80
+ act = Activation.none
81
+ norm = False
82
+ cond = False
83
+ a, b = conf.num_hid_channels, conf.num_channels
84
+ dropout = 0
85
+ else:
86
+ act = conf.activation
87
+ norm = conf.use_norm
88
+ cond = True
89
+ a, b = conf.num_hid_channels, conf.num_hid_channels
90
+ dropout = conf.dropout
91
+
92
+ if i in conf.skip_layers:
93
+ a += conf.num_channels
94
+
95
+ self.layers.append(
96
+ MLPLNAct(
97
+ a,
98
+ b,
99
+ norm=norm,
100
+ activation=act,
101
+ cond_channels=conf.num_channels,
102
+ use_cond=cond,
103
+ condition_bias=conf.condition_bias,
104
+ dropout=dropout,
105
+ )
106
+ )
107
+ self.last_act = conf.last_act.get_act()
108
+
109
+ def forward(self, x, t, **kwargs):
110
+ t = timestep_embedding(t, self.conf.num_time_emb_channels)
111
+ cond = self.time_embed(t)
112
+ h = x
113
+ for i in range(len(self.layers)):
114
+ if i in self.conf.skip_layers:
115
+ # injecting input into the hidden layers
116
+ h = torch.cat([h, x], dim=1)
117
+ h = self.layers[i].forward(x=h, cond=cond)
118
+ h = self.last_act(h)
119
+ return LatentNetReturn(h)
120
+
121
+
122
+ class MLPLNAct(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_channels: int,
126
+ out_channels: int,
127
+ norm: bool,
128
+ use_cond: bool,
129
+ activation: Activation,
130
+ cond_channels: int,
131
+ condition_bias: float = 0,
132
+ dropout: float = 0,
133
+ ):
134
+ super().__init__()
135
+ self.activation = activation
136
+ self.condition_bias = condition_bias
137
+ self.use_cond = use_cond
138
+
139
+ self.linear = nn.Linear(in_channels, out_channels)
140
+ self.act = activation.get_act()
141
+ if self.use_cond:
142
+ self.linear_emb = nn.Linear(cond_channels, out_channels)
143
+ self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144
+ if norm:
145
+ self.norm = nn.LayerNorm(out_channels)
146
+ else:
147
+ self.norm = nn.Identity()
148
+
149
+ if dropout > 0:
150
+ self.dropout = nn.Dropout(p=dropout)
151
+ else:
152
+ self.dropout = nn.Identity()
153
+
154
+ self.init_weights()
155
+
156
+ def init_weights(self):
157
+ for module in self.modules():
158
+ if isinstance(module, nn.Linear):
159
+ if self.activation == Activation.relu:
160
+ init.kaiming_normal_(module.weight, a=0, nonlinearity="relu")
161
+ elif self.activation == Activation.lrelu:
162
+ init.kaiming_normal_(
163
+ module.weight, a=0.2, nonlinearity="leaky_relu"
164
+ )
165
+ elif self.activation == Activation.silu:
166
+ init.kaiming_normal_(module.weight, a=0, nonlinearity="relu")
167
+ else:
168
+ # leave it as default
169
+ pass
170
+
171
+ def forward(self, x, cond=None):
172
+ x = self.linear(x)
173
+ if self.use_cond:
174
+ # (n, c) or (n, c * 2)
175
+ cond = self.cond_layers(cond)
176
+ cond = (cond, None)
177
+
178
+ # scale shift first
179
+ x = x * (self.condition_bias + cond[0])
180
+ if cond[1] is not None:
181
+ x = x + cond[1]
182
+ # then norm
183
+ x = self.norm(x)
184
+ else:
185
+ # no condition
186
+ x = self.norm(x)
187
+ x = self.act(x)
188
+ x = self.dropout(x)
189
+ return x
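A shape-level sketch of the skip-MLP latent denoiser; the sizes below are arbitrary illustrations, not the configuration of any released checkpoint.

    import torch

    from visualizr.model.latentnet import MLPSkipNetConfig

    net = MLPSkipNetConfig(
        num_channels=512, skip_layers=(1,), num_hid_channels=1024, num_layers=3
    ).make_model()
    z = torch.randn(4, 512)          # latent codes
    t = torch.tensor([1, 5, 9, 42])  # diffusion timesteps
    print(net(z, t).pred.shape)      # torch.Size([4, 512])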
src/visualizr/model/nn.py ADDED
@@ -0,0 +1,129 @@
1
+ import math
2
+
3
+ import torch as th
4
+ import torch.nn as nn
5
+ import torch.utils.checkpoint
6
+
7
+
8
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
9
+ class SiLU(nn.Module):
10
+ # @th.jit.script
11
+ def forward(self, x):
12
+ return x * th.sigmoid(x)
13
+
14
+
15
+ class GroupNorm32(nn.GroupNorm):
16
+ def forward(self, x):
17
+ return super().forward(x.float()).type(x.dtype)
18
+
19
+
20
+ def conv_nd(dims, *args, **kwargs):
21
+ """
22
+ Create a 1D, 2D, or 3D convolution module.
23
+ """
24
+ if dims == 1:
25
+ return nn.Conv1d(*args, **kwargs)
26
+ elif dims == 2:
27
+ return nn.Conv2d(*args, **kwargs)
28
+ elif dims == 3:
29
+ return nn.Conv3d(*args, **kwargs)
30
+ raise ValueError(f"unsupported dimensions: {dims}")
31
+
32
+
33
+ def linear(*args, **kwargs):
34
+ """
35
+ Create a linear module.
36
+ """
37
+ return nn.Linear(*args, **kwargs)
38
+
39
+
40
+ def avg_pool_nd(dims, *args, **kwargs):
41
+ """
42
+ Create a 1D, 2D, or 3D average pooling module.
43
+ """
44
+ if dims == 1:
45
+ return nn.AvgPool1d(*args, **kwargs)
46
+ elif dims == 2:
47
+ return nn.AvgPool2d(*args, **kwargs)
48
+ elif dims == 3:
49
+ return nn.AvgPool3d(*args, **kwargs)
50
+ raise ValueError(f"unsupported dimensions: {dims}")
51
+
52
+
53
+ def update_ema(target_params, source_params, rate=0.99):
54
+ """
55
+ Update target parameters to be closer to those of source parameters using
56
+ an exponential moving average.
57
+
58
+ :param target_params: The target parameter sequence.
59
+ :param source_params: The source parameter sequence.
60
+ :param rate: The EMA rate (closer to 1 means slower).
61
+ """
62
+ for targ, src in zip(target_params, source_params):
63
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def scale_module(module, scale):
76
+ """
77
+ Scale the parameters of a module and return it.
78
+ """
79
+ for p in module.parameters():
80
+ p.detach().mul_(scale)
81
+ return module
82
+
83
+
84
+ def mean_flat(tensor):
85
+ """
86
+ Take the mean over all non-batch dimensions.
87
+ """
88
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
89
+
90
+
91
+ def normalization(channels):
92
+ """
93
+ Make a standard normalization layer.
94
+
95
+ :param channels: Number of input channels.
96
+ :return: A nn.Module for normalization.
97
+ """
98
+ return GroupNorm32(min(32, channels), channels)
99
+
100
+
101
+ def timestep_embedding(timesteps, dim, max_period=10000):
102
+ """
103
+ Create sinusoidal timestep embeddings.
104
+
105
+ :param timesteps: A 1-D Tensor of N indices, one per batch element.
106
+ These may be fractional.
107
+ :param dim: The dimension of the output.
108
+ :param max_period: Controls the minimum frequency of the embeddings.
109
+ :return: An [N x dim] Tensor of positional embeddings.
110
+ """
111
+ half = dim // 2
112
+ freqs = th.exp(
113
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
114
+ ).to(device=timesteps.device)
115
+ args = timesteps[:, None].float() * freqs[None]
116
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
117
+ if dim % 2:
118
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
119
+ return embedding
120
+
121
+
122
+ def torch_checkpoint(func, args, flag, preserve_rng_state=False):
123
+ # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
124
+ if flag:
125
+ return torch.utils.checkpoint.checkpoint(
126
+ func, *args, preserve_rng_state=preserve_rng_state
127
+ )
128
+ else:
129
+ return func(*args)
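A short sketch of `update_ema` above, maintaining a slowly moving copy of a model's weights; the layer and rate are arbitrary.

    import copy

    import torch

    from visualizr.model.nn import update_ema

    model = torch.nn.Linear(8, 8)
    ema_model = copy.deepcopy(model)
    # ... after each optimizer step on `model` ...
    update_ema(ema_model.parameters(), model.parameters(), rate=0.999)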
src/visualizr/model/seq2seq.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from espnet.nets.pytorch_backend.conformer.encoder import Encoder
4
+ from torch import nn
5
+
6
+ from visualizr import logger
7
+ from visualizr.model.base import BaseModule
8
+
9
+
10
+ class LSTM(nn.Module):
11
+ def __init__(self, motion_dim, output_dim, num_layers=2, hidden_dim=128):
12
+ super().__init__()
13
+ self.lstm = nn.LSTM(
14
+ input_size=motion_dim,
15
+ hidden_size=hidden_dim,
16
+ num_layers=num_layers,
17
+ batch_first=True,
18
+ )
19
+ self.fc = nn.Linear(hidden_dim, output_dim)
20
+
21
+ def forward(self, x):
22
+ x, _ = self.lstm(x)
23
+ return self.fc(x)
24
+
25
+
26
+ class DiffusionPredictor(BaseModule):
27
+ def __init__(self, conf):
28
+ super(DiffusionPredictor, self).__init__()
29
+
30
+ self.infer_type = conf.infer_type
31
+
32
+ self.initialize_layers(conf)
33
+ logger.info(f"infer_type: {self.infer_type}")
34
+
35
+ def create_conformer_encoder(self, attention_dim, num_blocks):
36
+ return Encoder(
37
+ idim=0,
38
+ attention_dim=attention_dim,
39
+ attention_heads=2,
40
+ linear_units=attention_dim,
41
+ num_blocks=num_blocks,
42
+ input_layer=None,
43
+ dropout_rate=0.2,
44
+ positional_dropout_rate=0.2,
45
+ attention_dropout_rate=0.2,
46
+ normalize_before=False,
47
+ concat_after=False,
48
+ positionwise_layer_type="linear",
49
+ positionwise_conv_kernel_size=3,
50
+ macaron_style=True,
51
+ pos_enc_layer_type="rel_pos",
52
+ selfattention_layer_type="rel_selfattn",
53
+ use_cnn_module=True,
54
+ cnn_module_kernel=13,
55
+ )
56
+
57
+ def initialize_layers(
58
+ self,
59
+ conf,
60
+ mfcc_dim=39,
61
+ hubert_dim=1024,
62
+ speech_layers=4,
63
+ speech_dim=512,
64
+ decoder_dim=1024,
65
+ motion_start_dim=512,
66
+ HAL_layers=25,
67
+ ):
68
+ self.conf = conf
69
+ # Speech downsampling
70
+ if self.infer_type.startswith("mfcc"):
71
+ # from 100 hz to 25 hz
72
+ self.down_sample1 = nn.Conv1d(
73
+ mfcc_dim, 256, kernel_size=3, stride=2, padding=1
74
+ )
75
+ self.down_sample2 = nn.Conv1d(
76
+ 256, speech_dim, kernel_size=3, stride=2, padding=1
77
+ )
78
+ elif self.infer_type.startswith("hubert"):
79
+ # from 50 hz to 25 hz
80
+ self.down_sample1 = nn.Conv1d(
81
+ hubert_dim, speech_dim, kernel_size=3, stride=2, padding=1
82
+ )
83
+
84
+ self.weights = nn.Parameter(torch.zeros(HAL_layers))
85
+ self.speech_encoder = self.create_conformer_encoder(
86
+ speech_dim, speech_layers
87
+ )
88
+ else:
89
+ logger.exception("infer_type not supported")
90
+
91
+ # Encoders & Decoders
92
+ self.coarse_decoder = self.create_conformer_encoder(
93
+ decoder_dim, conf.decoder_layers
94
+ )
95
+
96
+ # LSTM predictors for Variance Adapter
97
+ if self.infer_type != "hubert_audio_only":
98
+ self.pose_predictor = LSTM(speech_dim, 3)
99
+ self.pose_encoder = LSTM(3, speech_dim)
100
+
101
+ if "full_control" in self.infer_type:
102
+ self.location_predictor = LSTM(speech_dim, 1)
103
+ self.location_encoder = LSTM(1, speech_dim)
104
+ self.face_scale_predictor = LSTM(speech_dim, 1)
105
+ self.face_scale_encoder = LSTM(1, speech_dim)
106
+
107
+ # Linear transformations
108
+ self.init_code_proj = nn.Sequential(nn.Linear(motion_start_dim, 128))
109
+ self.noisy_encoder = nn.Sequential(nn.Linear(conf.motion_dim, 128))
110
+ self.t_encoder = nn.Sequential(nn.Linear(1, 128))
111
+ self.encoder_direction_code = nn.Linear(conf.motion_dim, 128)
112
+
113
+ self.out_proj = nn.Linear(decoder_dim, conf.motion_dim)
114
+
115
+ def forward(
116
+ self,
117
+ initial_code,
118
+ direction_code,
119
+ seq_input_vector,
120
+ face_location,
121
+ face_scale,
122
+ yaw_pitch_roll,
123
+ noisy_x,
124
+ t_emb,
125
+ control_flag=False,
126
+ ):
127
+ # `x` is assigned by exactly one of the speech-feature branches below
+ if self.infer_type.startswith("mfcc"):
129
+ x = self.mfcc_speech_downsample(seq_input_vector)
130
+ elif self.infer_type.startswith("hubert"):
131
+ norm_weights = F.softmax(self.weights, dim=-1)
132
+ weighted_feature = (
133
+ norm_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) * seq_input_vector
134
+ ).sum(dim=1)
135
+ x = self.down_sample1(weighted_feature.transpose(1, 2)).transpose(1, 2)
136
+ x, _ = self.speech_encoder(x, masks=None)
137
+ predicted_location, predicted_scale, predicted_pose = (
138
+ face_location,
139
+ face_scale,
140
+ yaw_pitch_roll,
141
+ )
142
+ if self.infer_type != "hubert_audio_only":
143
+ logger.info(f"pose controllable. control_flag: {control_flag}")
144
+ x, predicted_location, predicted_scale, predicted_pose = (
145
+ self.adjust_features(
146
+ x, face_location, face_scale, yaw_pitch_roll, control_flag
147
+ )
148
+ )
149
+ # initial_code and direction_code serve as a motion guide extracted from
+ # the reference image: they tell the model what the starting motion should be.
152
+ concatenated_features = self.combine_features(
153
+ x, initial_code, direction_code, noisy_x, t_emb
154
+ )
155
+ outputs = self.decode_features(concatenated_features)
156
+ return outputs, predicted_location, predicted_scale, predicted_pose
157
+
158
+ def mfcc_speech_downsample(self, seq_input_vector):
159
+ x = self.down_sample1(seq_input_vector.transpose(1, 2))
160
+ return self.down_sample2(x).transpose(1, 2)
161
+
162
+ def adjust_features(
163
+ self, x, face_location, face_scale, yaw_pitch_roll, control_flag
164
+ ):
165
+ predicted_location, predicted_scale = 0, 0
166
+ if "full_control" in self.infer_type:
167
+ logger.info(f"full controllable. control_flag: {control_flag}")
168
+ x_residual, predicted_location = self.adjust_location(
169
+ x, face_location, control_flag
170
+ )
171
+ x = x + x_residual
172
+
173
+ x_residual, predicted_scale = self.adjust_scale(x, face_scale, control_flag)
174
+ x = x + x_residual
175
+
176
+ x_residual, predicted_pose = self.adjust_pose(x, yaw_pitch_roll, control_flag)
177
+ x = x + x_residual
178
+ return x, predicted_location, predicted_scale, predicted_pose
179
+
180
+ def adjust_location(self, x, face_location, control_flag):
181
+ if control_flag:
182
+ predicted_location = face_location
183
+ else:
184
+ predicted_location = self.location_predictor(x)
185
+ return self.location_encoder(predicted_location), predicted_location
186
+
187
+ def adjust_scale(self, x, face_scale, control_flag):
188
+ if control_flag:
189
+ predicted_face_scale = face_scale
190
+ else:
191
+ predicted_face_scale = self.face_scale_predictor(x)
192
+ return self.face_scale_encoder(predicted_face_scale), predicted_face_scale
193
+
194
+ def adjust_pose(self, x, yaw_pitch_roll, control_flag):
195
+ if control_flag:
196
+ predicted_pose = yaw_pitch_roll
197
+ else:
198
+ predicted_pose = self.pose_predictor(x)
199
+ return self.pose_encoder(predicted_pose), predicted_pose
200
+
201
+ def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb):
202
+ init_code_proj = (
203
+ self.init_code_proj(initial_code).unsqueeze(1).repeat(1, x.size(1), 1)
204
+ )
205
+ noisy_feature = self.noisy_encoder(noisy_x)
206
+ t_emb_feature = (
207
+ self.t_encoder(t_emb.unsqueeze(1).float())
208
+ .unsqueeze(1)
209
+ .repeat(1, x.size(1), 1)
210
+ )
211
+ direction_code_feature = (
212
+ self.encoder_direction_code(direction_code)
213
+ .unsqueeze(1)
214
+ .repeat(1, x.size(1), 1)
215
+ )
216
+ return torch.cat(
217
+ (x, direction_code_feature, init_code_proj, noisy_feature, t_emb_feature),
218
+ dim=-1,
219
+ )
220
+
221
+ def decode_features(self, concatenated_features):
222
+ outputs, _ = self.coarse_decoder(concatenated_features, masks=None)
223
+ return self.out_proj(outputs)
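A shape-level sketch of the small `LSTM` regressor defined above, in the role the pose predictor plays (speech features in, three pose values per frame out); the sizes are illustrative.

    import torch

    from visualizr.model.seq2seq import LSTM

    pose_head = LSTM(motion_dim=512, output_dim=3)  # e.g. yaw / pitch / roll per frame
    speech = torch.randn(2, 100, 512)               # (batch, frames, feature dim)
    print(pose_head(speech).shape)                  # torch.Size([2, 100, 3])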
src/visualizr/model/unet.py ADDED
@@ -0,0 +1,561 @@
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Tuple
3
+
4
+ import torch as th
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from visualizr.config_base import BaseConfig
9
+ from visualizr.model.blocks import (
10
+ AttentionBlock,
11
+ Downsample,
12
+ ResBlockConfig,
13
+ TimestepEmbedSequential,
14
+ Upsample,
15
+ )
16
+ from visualizr.model.nn import (
17
+ conv_nd,
18
+ linear,
19
+ normalization,
20
+ timestep_embedding,
21
+ zero_module,
22
+ )
23
+
24
+
25
+ @dataclass
26
+ class BeatGANsUNetConfig(BaseConfig):
27
+ image_size: int = 64
28
+ in_channels: int = 3
29
+ # base channels will be multiplied
30
+ model_channels: int = 64
31
+ # output of the unet
32
+ # suggest: 3
33
+ # you only need 6 if you also model the variance of the noise prediction
34
+ # (usually we use an analytical variance hence 3)
35
+ out_channels: int = 3
36
+ # how many repeating resblocks per resolution
37
+ # the decoding side would have "one more" resblock
38
+ # default: 2
39
+ num_res_blocks: int = 2
40
+ # you can also set the number of resblocks specifically for the input blocks
41
+ # default: None = above
42
+ num_input_res_blocks: int = None
43
+ # number of time embed channels and style channels
44
+ embed_channels: int = 512
45
+ # at what resolutions you want to do self-attention of the feature maps
46
+ # attentions generally improve performance
47
+ # default: [16]
48
+ # beatgans: [32, 16, 8]
49
+ attention_resolutions: Tuple[int] = (16,)
50
+ # number of time embed channels
51
+ time_embed_channels: int = None
52
+ # dropout applies to the resblocks (on feature maps)
53
+ dropout: float = 0.1
54
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
55
+ input_channel_mult: Tuple[int] = None
56
+ conv_resample: bool = True
57
+ # always 2 = 2d conv
58
+ dims: int = 2
59
+ # don't use this, legacy from BeatGANs
60
+ num_classes: int = None
61
+ use_checkpoint: bool = False
62
+ # number of attention heads
63
+ num_heads: int = 1
64
+ # or specify the number of channels per attention head
65
+ num_head_channels: int = -1
66
+ # number of attention heads in the upsampling blocks (-1 = same as num_heads)
+ num_heads_upsample: int = -1
68
+ # use resblock for upscale/downscale blocks (expensive)
69
+ # default: True (BeatGANs)
70
+ resblock_updown: bool = True
71
+ # never tried
72
+ use_new_attention_order: bool = False
73
+ resnet_two_cond: bool = False
74
+ resnet_cond_channels: int = None
75
+ # init the decoding conv layers with zero weights, this speeds up training
76
+ # default: True (BeatGANs)
77
+ resnet_use_zero_module: bool = True
78
+ # gradient checkpoint the attention operation
79
+ attn_checkpoint: bool = False
80
+
81
+ def make_model(self):
82
+ return BeatGANsUNetModel(self)
83
+
84
+
85
+ class BeatGANsUNetModel(nn.Module):
86
+ def __init__(self, conf: BeatGANsUNetConfig):
87
+ super().__init__()
88
+ self.conf = conf
89
+
90
+ if conf.num_heads_upsample == -1:
91
+ self.num_heads_upsample = conf.num_heads
92
+
93
+ self.dtype = th.float32
94
+
95
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
96
+ self.time_embed = nn.Sequential(
97
+ linear(self.time_emb_channels, conf.embed_channels),
98
+ nn.SiLU(),
99
+ linear(conf.embed_channels, conf.embed_channels),
100
+ )
101
+
102
+ if conf.num_classes is not None:
103
+ self.label_emb = nn.Embedding(conf.num_classes, conf.embed_channels)
104
+
105
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
106
+ self.input_blocks = nn.ModuleList(
107
+ [
108
+ TimestepEmbedSequential(
109
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)
110
+ )
111
+ ]
112
+ )
113
+
114
+ kwargs = dict(
115
+ use_condition=True,
116
+ two_cond=conf.resnet_two_cond,
117
+ use_zero_module=conf.resnet_use_zero_module,
118
+ # style channels for the resnet block
119
+ cond_emb_channels=conf.resnet_cond_channels,
120
+ )
121
+
122
+ self._feature_size = ch
123
+
124
+ # input_block_chans = [ch]
125
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
126
+ input_block_chans[0].append(ch)
127
+
128
+ # number of blocks at each resolution
129
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
130
+ self.input_num_blocks[0] = 1
131
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
132
+
133
+ ds = 1
134
+ resolution = conf.image_size
135
+ for level, mult in enumerate(conf.input_channel_mult or conf.channel_mult):
136
+ for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
137
+ layers = [
138
+ ResBlockConfig(
139
+ ch,
140
+ conf.embed_channels,
141
+ conf.dropout,
142
+ out_channels=int(mult * conf.model_channels),
143
+ dims=conf.dims,
144
+ use_checkpoint=conf.use_checkpoint,
145
+ **kwargs,
146
+ ).make_model()
147
+ ]
148
+ ch = int(mult * conf.model_channels)
149
+ if resolution in conf.attention_resolutions:
150
+ layers.append(
151
+ AttentionBlock(
152
+ ch,
153
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
154
+ num_heads=conf.num_heads,
155
+ num_head_channels=conf.num_head_channels,
156
+ use_new_attention_order=conf.use_new_attention_order,
157
+ )
158
+ )
159
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
160
+ self._feature_size += ch
161
+ # input_block_chans.append(ch)
162
+ input_block_chans[level].append(ch)
163
+ self.input_num_blocks[level] += 1
164
+ # print(input_block_chans)
165
+ if level != len(conf.channel_mult) - 1:
166
+ resolution //= 2
167
+ out_ch = ch
168
+ self.input_blocks.append(
169
+ TimestepEmbedSequential(
170
+ ResBlockConfig(
171
+ ch,
172
+ conf.embed_channels,
173
+ conf.dropout,
174
+ out_channels=out_ch,
175
+ dims=conf.dims,
176
+ use_checkpoint=conf.use_checkpoint,
177
+ down=True,
178
+ **kwargs,
179
+ ).make_model()
180
+ if conf.resblock_updown
181
+ else Downsample(ch, conf.conv_resample, conf.dims, out_ch)
182
+ )
183
+ )
184
+ ch = out_ch
185
+ # input_block_chans.append(ch)
186
+ input_block_chans[level + 1].append(ch)
187
+ self.input_num_blocks[level + 1] += 1
188
+ ds *= 2
189
+ self._feature_size += ch
190
+
191
+ self.middle_block = TimestepEmbedSequential(
192
+ ResBlockConfig(
193
+ ch,
194
+ conf.embed_channels,
195
+ conf.dropout,
196
+ dims=conf.dims,
197
+ use_checkpoint=conf.use_checkpoint,
198
+ **kwargs,
199
+ ).make_model(),
200
+ AttentionBlock(
201
+ ch,
202
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
203
+ num_heads=conf.num_heads,
204
+ num_head_channels=conf.num_head_channels,
205
+ use_new_attention_order=conf.use_new_attention_order,
206
+ ),
207
+ ResBlockConfig(
208
+ ch,
209
+ conf.embed_channels,
210
+ conf.dropout,
211
+ dims=conf.dims,
212
+ use_checkpoint=conf.use_checkpoint,
213
+ **kwargs,
214
+ ).make_model(),
215
+ )
216
+ self._feature_size += ch
217
+
218
+ self.output_blocks = nn.ModuleList([])
219
+ for level, mult in list(enumerate(conf.channel_mult))[::-1]:
220
+ for i in range(conf.num_res_blocks + 1):
221
+ # print(input_block_chans)
222
+ # ich = input_block_chans.pop()
223
+ try:
224
+ ich = input_block_chans[level].pop()
225
+ except IndexError:
226
+ # this happens only when num_res_block > num_enc_res_block
227
+ # we will not have enough lateral (skip) connections for all decoder blocks
228
+ ich = 0
229
+ # print('pop:', ich)
230
+ layers = [
231
+ ResBlockConfig(
232
+ # only direct channels when gated
233
+ channels=ch + ich,
234
+ emb_channels=conf.embed_channels,
235
+ dropout=conf.dropout,
236
+ out_channels=int(conf.model_channels * mult),
237
+ dims=conf.dims,
238
+ use_checkpoint=conf.use_checkpoint,
239
+ # lateral channels are described here when gated
240
+ has_lateral=True if ich > 0 else False,
241
+ lateral_channels=None,
242
+ **kwargs,
243
+ ).make_model()
244
+ ]
245
+ ch = int(conf.model_channels * mult)
246
+ if resolution in conf.attention_resolutions:
247
+ layers.append(
248
+ AttentionBlock(
249
+ ch,
250
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
251
+ num_heads=self.num_heads_upsample,
252
+ num_head_channels=conf.num_head_channels,
253
+ use_new_attention_order=conf.use_new_attention_order,
254
+ )
255
+ )
256
+ if level and i == conf.num_res_blocks:
257
+ resolution *= 2
258
+ out_ch = ch
259
+ layers.append(
260
+ ResBlockConfig(
261
+ ch,
262
+ conf.embed_channels,
263
+ conf.dropout,
264
+ out_channels=out_ch,
265
+ dims=conf.dims,
266
+ use_checkpoint=conf.use_checkpoint,
267
+ up=True,
268
+ **kwargs,
269
+ ).make_model()
270
+ if conf.resblock_updown
271
+ else Upsample(
272
+ ch, conf.conv_resample, dims=conf.dims, out_channels=out_ch
273
+ )
274
+ )
275
+ ds //= 2
276
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
277
+ self.output_num_blocks[level] += 1
278
+ self._feature_size += ch
279
+
280
+ # print(input_block_chans)
281
+ # print('inputs:', self.input_num_blocks)
282
+ # print('outputs:', self.output_num_blocks)
283
+
284
+ if conf.resnet_use_zero_module:
285
+ self.out = nn.Sequential(
286
+ normalization(ch),
287
+ nn.SiLU(),
288
+ zero_module(
289
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1)
290
+ ),
291
+ )
292
+ else:
293
+ self.out = nn.Sequential(
294
+ normalization(ch),
295
+ nn.SiLU(),
296
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
297
+ )
298
+
299
+ def forward(self, x, t, y=None, **kwargs):
300
+ """
301
+ Apply the model to an input batch.
302
+
303
+ :param x: an [N x C x ...] Tensor of inputs.
304
+ :param t: a 1-D batch of timesteps.
305
+ :param y: an [N] Tensor of labels, if class-conditional.
306
+ :return: an [N x C x ...] Tensor of outputs.
307
+ """
308
+ assert (y is not None) == (self.conf.num_classes is not None), (
309
+ "must specify y if and only if the model is class-conditional"
310
+ )
311
+
312
+ # hs = []
313
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
314
+ emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
315
+
316
+ if self.conf.num_classes is not None:
317
+ raise NotImplementedError()
318
+ # assert y.shape == (x.shape[0], )
319
+ # emb = emb + self.label_emb(y)
320
+
321
+ # new code supports input_num_blocks != output_num_blocks
322
+ h = x.type(self.dtype)
323
+ k = 0
324
+ for i in range(len(self.input_num_blocks)):
325
+ for j in range(self.input_num_blocks[i]):
326
+ h = self.input_blocks[k](h, emb=emb)
327
+ # print(i, j, h.shape)
328
+ hs[i].append(h)
329
+ k += 1
330
+ assert k == len(self.input_blocks)
331
+
332
+ h = self.middle_block(h, emb=emb)
333
+ k = 0
334
+ for i in range(len(self.output_num_blocks)):
335
+ for j in range(self.output_num_blocks[i]):
336
+ # take the lateral connection from the same layer (in reverse)
337
+ # until there is no more, use None
338
+ try:
339
+ lateral = hs[-i - 1].pop()
340
+ # print(i, j, lateral.shape)
341
+ except IndexError:
342
+ lateral = None
343
+ # print(i, j, lateral)
344
+ h = self.output_blocks[k](h, emb=emb, lateral=lateral)
345
+ k += 1
346
+
347
+ h = h.type(x.dtype)
348
+ pred = self.out(h)
349
+ return Return(pred=pred)
350
+
351
+
352
+ class Return(NamedTuple):
353
+ pred: th.Tensor
354
+
355
+
356
+ @dataclass
357
+ class BeatGANsEncoderConfig(BaseConfig):
358
+ image_size: int
359
+ in_channels: int
360
+ model_channels: int
361
+ out_hid_channels: int
362
+ out_channels: int
363
+ num_res_blocks: int
364
+ attention_resolutions: Tuple[int]
365
+ dropout: float = 0
366
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
367
+ use_time_condition: bool = True
368
+ conv_resample: bool = True
369
+ dims: int = 2
370
+ use_checkpoint: bool = False
371
+ num_heads: int = 1
372
+ num_head_channels: int = -1
373
+ resblock_updown: bool = False
374
+ use_new_attention_order: bool = False
375
+ pool: str = "adaptivenonzero"
376
+
377
+ def make_model(self):
378
+ return BeatGANsEncoderModel(self)
379
+
380
+
381
+ class BeatGANsEncoderModel(nn.Module):
382
+ """
383
+ The half UNet model with attention and timestep embedding.
384
+
385
+ For usage, see UNet.
386
+ """
387
+
388
+ def __init__(self, conf: BeatGANsEncoderConfig):
389
+ super().__init__()
390
+ self.conf = conf
391
+ self.dtype = th.float32
392
+
393
+ if conf.use_time_condition:
394
+ time_embed_dim = conf.model_channels * 4
395
+ self.time_embed = nn.Sequential(
396
+ linear(conf.model_channels, time_embed_dim),
397
+ nn.SiLU(),
398
+ linear(time_embed_dim, time_embed_dim),
399
+ )
400
+ else:
401
+ time_embed_dim = None
402
+
403
+ ch = int(conf.channel_mult[0] * conf.model_channels)
404
+ self.input_blocks = nn.ModuleList(
405
+ [
406
+ TimestepEmbedSequential(
407
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)
408
+ )
409
+ ]
410
+ )
411
+ self._feature_size = ch
412
+ input_block_chans = [ch]
413
+ ds = 1
414
+ resolution = conf.image_size
415
+ for level, mult in enumerate(conf.channel_mult):
416
+ for _ in range(conf.num_res_blocks):
417
+ layers = [
418
+ ResBlockConfig(
419
+ ch,
420
+ time_embed_dim,
421
+ conf.dropout,
422
+ out_channels=int(mult * conf.model_channels),
423
+ dims=conf.dims,
424
+ use_condition=conf.use_time_condition,
425
+ use_checkpoint=conf.use_checkpoint,
426
+ ).make_model()
427
+ ]
428
+ ch = int(mult * conf.model_channels)
429
+ if resolution in conf.attention_resolutions:
430
+ layers.append(
431
+ AttentionBlock(
432
+ ch,
433
+ use_checkpoint=conf.use_checkpoint,
434
+ num_heads=conf.num_heads,
435
+ num_head_channels=conf.num_head_channels,
436
+ use_new_attention_order=conf.use_new_attention_order,
437
+ )
438
+ )
439
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
440
+ self._feature_size += ch
441
+ input_block_chans.append(ch)
442
+ if level != len(conf.channel_mult) - 1:
443
+ resolution //= 2
444
+ out_ch = ch
445
+ self.input_blocks.append(
446
+ TimestepEmbedSequential(
447
+ ResBlockConfig(
448
+ ch,
449
+ time_embed_dim,
450
+ conf.dropout,
451
+ out_channels=out_ch,
452
+ dims=conf.dims,
453
+ use_condition=conf.use_time_condition,
454
+ use_checkpoint=conf.use_checkpoint,
455
+ down=True,
456
+ ).make_model()
457
+ if (conf.resblock_updown)
458
+ else Downsample(
459
+ ch, conf.conv_resample, dims=conf.dims, out_channels=out_ch
460
+ )
461
+ )
462
+ )
463
+ ch = out_ch
464
+ input_block_chans.append(ch)
465
+ ds *= 2
466
+ self._feature_size += ch
467
+
468
+ self.middle_block = TimestepEmbedSequential(
469
+ ResBlockConfig(
470
+ ch,
471
+ time_embed_dim,
472
+ conf.dropout,
473
+ dims=conf.dims,
474
+ use_condition=conf.use_time_condition,
475
+ use_checkpoint=conf.use_checkpoint,
476
+ ).make_model(),
477
+ AttentionBlock(
478
+ ch,
479
+ use_checkpoint=conf.use_checkpoint,
480
+ num_heads=conf.num_heads,
481
+ num_head_channels=conf.num_head_channels,
482
+ use_new_attention_order=conf.use_new_attention_order,
483
+ ),
484
+ ResBlockConfig(
485
+ ch,
486
+ time_embed_dim,
487
+ conf.dropout,
488
+ dims=conf.dims,
489
+ use_condition=conf.use_time_condition,
490
+ use_checkpoint=conf.use_checkpoint,
491
+ ).make_model(),
492
+ )
493
+ self._feature_size += ch
494
+ if conf.pool == "adaptivenonzero":
495
+ self.out = nn.Sequential(
496
+ normalization(ch),
497
+ nn.SiLU(),
498
+ nn.AdaptiveAvgPool2d((1, 1)),
499
+ conv_nd(conf.dims, ch, conf.out_channels, 1),
500
+ nn.Flatten(),
501
+ )
502
+ else:
503
+ raise NotImplementedError(f"Unexpected {conf.pool} pooling")
504
+
505
+ def forward(self, x, t=None, return_2d_feature=False):
506
+ """
507
+ Apply the model to an input batch.
508
+
509
+ :param x: an [N x C x ...] Tensor of inputs.
510
+ :param t: a 1-D batch of timesteps (unused when use_time_condition is False).
511
+ :return: an [N x K] Tensor of outputs.
512
+ """
513
+ if self.conf.use_time_condition:
514
+ emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
515
+ else:
516
+ emb = None
517
+
518
+ results = []
519
+ h = x.type(self.dtype)
520
+ for module in self.input_blocks:
521
+ h = module(h, emb=emb)
522
+ if self.conf.pool.startswith("spatial"):
523
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
524
+ h = self.middle_block(h, emb=emb)
525
+ if self.conf.pool.startswith("spatial"):
526
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
527
+ h = th.cat(results, axis=-1)
528
+ else:
529
+ h = h.type(x.dtype)
530
+
531
+ h_2d = h
532
+ h = self.out(h)
533
+
534
+ if return_2d_feature:
535
+ return h, h_2d
536
+ else:
537
+ return h
538
+
539
+ def forward_flatten(self, x):
540
+ """
541
+ Transform the last 2D feature map into a flattened vector.
542
+ """
543
+ h = self.out(x)
544
+ return h
545
+
546
+
547
+ class SuperResModel(BeatGANsUNetModel):
548
+ """
549
+ A UNetModel that performs super-resolution.
550
+
551
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
552
+ """
553
+
554
+ def __init__(self, image_size, in_channels, *args, **kwargs):
555
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
556
+
557
+ def forward(self, x, timesteps, low_res=None, **kwargs):
558
+ _, _, new_height, new_width = x.shape
559
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
560
+ x = th.cat([x, upsampled], dim=1)
561
+ return super().forward(x, timesteps, **kwargs)
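A minimal usage sketch for the UNet defined above (not part of the uploaded files): it assumes the visualizr package in this repo is importable and that the helpers in visualizr.model.blocks / visualizr.model.nn behave as in the diffae codebase this module appears to derive from.

import torch
from visualizr.model.unet import BeatGANsUNetConfig

# default config: attention at 16x16 feature maps, channel_mult (1, 2, 4, 8)
conf = BeatGANsUNetConfig(image_size=64, in_channels=3, model_channels=64, out_channels=3)
model = conf.make_model()

x = torch.randn(2, 3, 64, 64)      # [N, C, H, W] noisy images
t = torch.randint(0, 1000, (2,))   # one diffusion timestep per sample
out = model(x, t)                  # forward returns the Return named tuple
print(out.pred.shape)              # expected: torch.Size([2, 3, 64, 64])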
src/visualizr/model/unet_autoenc.py ADDED
@@ -0,0 +1,291 @@
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+ from visualizr.model import BeatGANsUNetConfig, BeatGANsUNetModel
8
+ from visualizr.model.blocks import ResBlock
9
+ from visualizr.model.latentnet import MLPSkipNetConfig
10
+ from visualizr.model.nn import linear, timestep_embedding
11
+ from visualizr.model.unet import BeatGANsEncoderConfig
12
+
13
+
14
+ @dataclass
15
+ class BeatGANsAutoencConfig(BeatGANsUNetConfig):
16
+ # number of style channels
17
+ enc_out_channels: int = 512
18
+ enc_attn_resolutions: Tuple[int] = None
19
+ enc_pool: str = "depthconv"
20
+ enc_num_res_block: int = 2
21
+ enc_channel_mult: Tuple[int] = None
22
+ enc_grad_checkpoint: bool = False
23
+ latent_net_conf: MLPSkipNetConfig = None
24
+
25
+ def make_model(self):
26
+ return BeatGANsAutoencModel(self)
27
+
28
+
29
+ class BeatGANsAutoencModel(BeatGANsUNetModel):
30
+ def __init__(self, conf: BeatGANsAutoencConfig):
31
+ super().__init__(conf)
32
+ self.conf = conf
33
+
34
+ # having only time, cond
35
+ self.time_embed = TimeStyleSeperateEmbed(
36
+ time_channels=conf.model_channels,
37
+ time_out_channels=conf.embed_channels,
38
+ )
39
+
40
+ self.encoder = BeatGANsEncoderConfig(
41
+ image_size=conf.image_size,
42
+ in_channels=conf.in_channels,
43
+ model_channels=conf.model_channels,
44
+ out_hid_channels=conf.enc_out_channels,
45
+ out_channels=conf.enc_out_channels,
46
+ num_res_blocks=conf.enc_num_res_block,
47
+ attention_resolutions=(
48
+ conf.enc_attn_resolutions or conf.attention_resolutions
49
+ ),
50
+ dropout=conf.dropout,
51
+ channel_mult=conf.enc_channel_mult or conf.channel_mult,
52
+ use_time_condition=False,
53
+ conv_resample=conf.conv_resample,
54
+ dims=conf.dims,
55
+ use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
56
+ num_heads=conf.num_heads,
57
+ num_head_channels=conf.num_head_channels,
58
+ resblock_updown=conf.resblock_updown,
59
+ use_new_attention_order=conf.use_new_attention_order,
60
+ pool=conf.enc_pool,
61
+ ).make_model()
62
+
63
+ if conf.latent_net_conf is not None:
64
+ self.latent_net = conf.latent_net_conf.make_model()
65
+
66
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
67
+ """
68
+ Reparameterization trick to sample from N(mu, var) using
69
+ N(0,1).
70
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
71
+ :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
72
+ :return: (Tensor) [B x D]
73
+ """
74
+ assert self.conf.is_stochastic
75
+ std = torch.exp(0.5 * logvar)
76
+ eps = torch.randn_like(std)
77
+ return eps * std + mu
78
+
79
+ def sample_z(self, n: int, device):
80
+ assert self.conf.is_stochastic
81
+ return torch.randn(n, self.conf.enc_out_channels, device=device)
82
+
83
+ def noise_to_cond(self, noise: Tensor):
84
+ raise NotImplementedError()
85
+ # assert self.conf.noise_net_conf is not None
86
+ # return self.noise_net.forward(noise)
87
+
88
+ def encode(self, x):
89
+ cond = self.encoder.forward(x)
90
+ return {"cond": cond}
91
+
92
+ @property
93
+ def stylespace_sizes(self):
94
+ modules = (
95
+ list(self.input_blocks.modules())
96
+ + list(self.middle_block.modules())
97
+ + list(self.output_blocks.modules())
98
+ )
99
+ sizes = []
100
+ for module in modules:
101
+ if isinstance(module, ResBlock):
102
+ linear = module.cond_emb_layers[-1]
103
+ sizes.append(linear.weight.shape[0])
104
+ return sizes
105
+
106
+ def encode_stylespace(self, x, return_vector: bool = True):
107
+ """
108
+ encode to style space
109
+ """
110
+ modules = (
111
+ list(self.input_blocks.modules())
112
+ + list(self.middle_block.modules())
113
+ + list(self.output_blocks.modules())
114
+ )
115
+ # (n, c)
116
+ cond = self.encoder.forward(x)
117
+ S = []
118
+ for module in modules:
119
+ if isinstance(module, ResBlock):
120
+ # (n, c')
121
+ s = module.cond_emb_layers.forward(cond)
122
+ S.append(s)
123
+
124
+ if return_vector:
125
+ # (n, sum_c)
126
+ return torch.cat(S, dim=1)
127
+ else:
128
+ return S
129
+
130
+ def forward(
131
+ self,
132
+ x,
133
+ t,
134
+ y=None,
135
+ x_start=None,
136
+ cond=None,
137
+ style=None,
138
+ noise=None,
139
+ t_cond=None,
140
+ **kwargs,
141
+ ):
142
+ """
143
+ Apply the model to an input batch.
144
+
145
+ Args:
146
+ x_start: the original image to encode
147
+ cond: output of the encoder
148
+ noise: random noise (to predict the cond)
149
+ """
150
+
151
+ if t_cond is None:
152
+ t_cond = t
153
+
154
+ if noise is not None:
155
+ # if the noise is given, we predict the cond from noise
156
+ cond = self.noise_to_cond(noise)
157
+
158
+ if cond is None:
159
+ if x is not None:
160
+ assert len(x) == len(x_start), f"{len(x)} != {len(x_start)}"
161
+
162
+ tmp = self.encode(x_start)
163
+ cond = tmp["cond"]
164
+
165
+ if t is not None:
166
+ _t_emb = timestep_embedding(t, self.conf.model_channels)
167
+ _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
168
+ else:
169
+ # this happens when training only autoenc
170
+ _t_emb = None
171
+ _t_cond_emb = None
172
+
173
+ if self.conf.resnet_two_cond:
174
+ res = self.time_embed.forward(
175
+ time_emb=_t_emb,
176
+ cond=cond,
177
+ time_cond_emb=_t_cond_emb,
178
+ )
179
+ else:
180
+ raise NotImplementedError()
181
+
182
+ if self.conf.resnet_two_cond:
183
+ # two cond: first = time emb, second = cond_emb
184
+ emb = res.time_emb
185
+ cond_emb = res.emb
186
+ else:
187
+ # one cond = combined of both time and cond
188
+ emb = res.emb
189
+ cond_emb = None
190
+
191
+ # override the style if given
192
+ style = style or res.style
193
+
194
+ assert (y is not None) == (self.conf.num_classes is not None), (
195
+ "must specify y if and only if the model is class-conditional"
196
+ )
197
+
198
+ if self.conf.num_classes is not None:
199
+ raise NotImplementedError()
200
+ # assert y.shape == (x.shape[0], )
201
+ # emb = emb + self.label_emb(y)
202
+
203
+ # where in the model to supply time conditions
204
+ enc_time_emb = emb
205
+ mid_time_emb = emb
206
+ dec_time_emb = emb
207
+ # where in the model to supply style conditions
208
+ enc_cond_emb = cond_emb
209
+ mid_cond_emb = cond_emb
210
+ dec_cond_emb = cond_emb
211
+
212
+ # hs = []
213
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
214
+
215
+ if x is not None:
216
+ h = x.type(self.dtype)
217
+
218
+ # input blocks
219
+ k = 0
220
+ for i in range(len(self.input_num_blocks)):
221
+ for j in range(self.input_num_blocks[i]):
222
+ h = self.input_blocks[k](h, emb=enc_time_emb, cond=enc_cond_emb)
223
+
224
+ # print(i, j, h.shape)
225
+ hs[i].append(h)
226
+ k += 1
227
+ assert k == len(self.input_blocks)
228
+
229
+ # middle blocks
230
+ h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
231
+ else:
232
+ # no lateral connections
233
+ # happens when training only the autoencoder
234
+ h = None
235
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
236
+
237
+ # output blocks
238
+ k = 0
239
+ for i in range(len(self.output_num_blocks)):
240
+ for j in range(self.output_num_blocks[i]):
241
+ # take the lateral connection from the same layer (in reverse)
242
+ # until there is no more, use None
243
+ try:
244
+ lateral = hs[-i - 1].pop()
245
+ # print(i, j, lateral.shape)
246
+ except IndexError:
247
+ lateral = None
248
+ # print(i, j, lateral)
249
+
250
+ h = self.output_blocks[k](
251
+ h, emb=dec_time_emb, cond=dec_cond_emb, lateral=lateral
252
+ )
253
+ k += 1
254
+
255
+ pred = self.out(h)
256
+ return AutoencReturn(pred=pred, cond=cond)
257
+
258
+
259
+ class AutoencReturn(NamedTuple):
260
+ pred: Tensor
261
+ cond: Tensor = None
262
+
263
+
264
+ class EmbedReturn(NamedTuple):
265
+ # style and time
266
+ emb: Tensor = None
267
+ # time only
268
+ time_emb: Tensor = None
269
+ # style only (but could depend on time)
270
+ style: Tensor = None
271
+
272
+
273
+ class TimeStyleSeperateEmbed(nn.Module):
274
+ # embed only style
275
+ def __init__(self, time_channels, time_out_channels):
276
+ super().__init__()
277
+ self.time_embed = nn.Sequential(
278
+ linear(time_channels, time_out_channels),
279
+ nn.SiLU(),
280
+ linear(time_out_channels, time_out_channels),
281
+ )
282
+ self.style = nn.Identity()
283
+
284
+ def forward(self, time_emb=None, cond=None, **kwargs):
285
+ if time_emb is None:
286
+ # happens with autoenc training mode
287
+ time_emb = None
288
+ else:
289
+ time_emb = self.time_embed(time_emb)
290
+ style = self.style(cond)
291
+ return EmbedReturn(emb=style, time_emb=time_emb, style=style)
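A hedged sketch of running the autoencoder variant end to end, assuming the ResBlock/config helpers follow the diffae lineage of this code. Two defaults must be overridden for a standalone run: the encoder only implements the "adaptivenonzero" pool, and forward raises NotImplementedError unless resnet_two_cond=True.

import torch
from visualizr.model.unet_autoenc import BeatGANsAutoencConfig

conf = BeatGANsAutoencConfig(
    image_size=64,
    in_channels=3,
    model_channels=64,
    out_channels=3,
    enc_out_channels=512,
    enc_pool="adaptivenonzero",  # the only pool BeatGANsEncoderModel implements
    resnet_two_cond=True,        # required by BeatGANsAutoencModel.forward
)
model = conf.make_model()

x_start = torch.randn(2, 3, 64, 64)    # clean images, encoded into the semantic code
x_t = torch.randn(2, 3, 64, 64)        # noisy images to denoise
t = torch.randint(0, 1000, (2,))
out = model(x_t, t, x_start=x_start)   # AutoencReturn(pred=..., cond=...)
print(out.pred.shape, out.cond.shape)  # expected: [2, 3, 64, 64] and [2, 512]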
src/visualizr/networks/__init__.py ADDED
File without changes
src/visualizr/networks/discriminator.py ADDED
@@ -0,0 +1,300 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
9
+ return F.leaky_relu(input + bias, negative_slope) * scale
10
+
11
+
12
+ class FusedLeakyReLU(nn.Module):
13
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
14
+ super().__init__()
15
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ # print("FusedLeakyReLU: ", input.abs().mean())
21
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
22
+ # print("FusedLeakyReLU: ", out.abs().mean())
23
+ return out
24
+
25
+
26
+ def upfirdn2d_native(
27
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
28
+ ):
29
+ _, minor, in_h, in_w = input.shape
30
+ kernel_h, kernel_w = kernel.shape
31
+
32
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
33
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
34
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
35
+
36
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
37
+ out = out[
38
+ :,
39
+ :,
40
+ max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
41
+ max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
42
+ ]
43
+
44
+ # out = out.permute(0, 3, 1, 2)
45
+ out = out.reshape(
46
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
47
+ )
48
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
49
+ out = F.conv2d(out, w)
50
+ out = out.reshape(
51
+ -1,
52
+ minor,
53
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
54
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
55
+ )
56
+ # out = out.permute(0, 2, 3, 1)
57
+
58
+ return out[:, :, ::down_y, ::down_x]
59
+
60
+
61
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
62
+ return upfirdn2d_native(
63
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
64
+ )
65
+
66
+
67
+ def make_kernel(k):
68
+ k = torch.tensor(k, dtype=torch.float32)
69
+
70
+ if k.ndim == 1:
71
+ k = k[None, :] * k[:, None]
72
+
73
+ k /= k.sum()
74
+
75
+ return k
76
+
77
+
78
+ class Blur(nn.Module):
79
+ def __init__(self, kernel, pad, upsample_factor=1):
80
+ super().__init__()
81
+
82
+ kernel = make_kernel(kernel)
83
+
84
+ if upsample_factor > 1:
85
+ kernel = kernel * (upsample_factor**2)
86
+
87
+ self.register_buffer("kernel", kernel)
88
+
89
+ self.pad = pad
90
+
91
+ def forward(self, input):
92
+ return upfirdn2d(input, self.kernel, pad=self.pad)
93
+
94
+
95
+ class ScaledLeakyReLU(nn.Module):
96
+ def __init__(self, negative_slope=0.2):
97
+ super().__init__()
98
+
99
+ self.negative_slope = negative_slope
100
+
101
+ def forward(self, input):
102
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
103
+
104
+
105
+ class EqualConv2d(nn.Module):
106
+ def __init__(
107
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
108
+ ):
109
+ super().__init__()
110
+
111
+ self.weight = nn.Parameter(
112
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
113
+ )
114
+ self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
115
+
116
+ self.stride = stride
117
+ self.padding = padding
118
+
119
+ if bias:
120
+ self.bias = nn.Parameter(torch.zeros(out_channel))
121
+ else:
122
+ self.bias = None
123
+
124
+ def forward(self, input):
125
+ return F.conv2d(
126
+ input,
127
+ self.weight * self.scale,
128
+ bias=self.bias,
129
+ stride=self.stride,
130
+ padding=self.padding,
131
+ )
132
+
133
+ def __repr__(self):
134
+ return (
135
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
136
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
137
+ )
138
+
139
+
140
+ class EqualLinear(nn.Module):
141
+ def __init__(
142
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
143
+ ):
144
+ super().__init__()
145
+
146
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
147
+
148
+ if bias:
149
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
150
+ else:
151
+ self.bias = None
152
+
153
+ self.activation = activation
154
+
155
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
156
+ self.lr_mul = lr_mul
157
+
158
+ def forward(self, input):
159
+ if self.activation:
160
+ out = F.linear(input, self.weight * self.scale)
161
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
162
+ else:
163
+ out = F.linear(
164
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
165
+ )
166
+
167
+ return out
168
+
169
+ def __repr__(self):
170
+ return (
171
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
172
+ )
173
+
174
+
175
+ class ConvLayer(nn.Sequential):
176
+ def __init__(
177
+ self,
178
+ in_channel,
179
+ out_channel,
180
+ kernel_size,
181
+ downsample=False,
182
+ blur_kernel=[1, 3, 3, 1],
183
+ bias=True,
184
+ activate=True,
185
+ ):
186
+ layers = []
187
+
188
+ if downsample:
189
+ factor = 2
190
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
191
+ pad0 = (p + 1) // 2
192
+ pad1 = p // 2
193
+
194
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
195
+
196
+ stride = 2
197
+ self.padding = 0
198
+
199
+ else:
200
+ stride = 1
201
+ self.padding = kernel_size // 2
202
+
203
+ layers.append(
204
+ EqualConv2d(
205
+ in_channel,
206
+ out_channel,
207
+ kernel_size,
208
+ padding=self.padding,
209
+ stride=stride,
210
+ bias=bias and not activate,
211
+ )
212
+ )
213
+
214
+ if activate:
215
+ if bias:
216
+ layers.append(FusedLeakyReLU(out_channel))
217
+ else:
218
+ layers.append(ScaledLeakyReLU(0.2))
219
+
220
+ super().__init__(*layers)
221
+
222
+
223
+ class ResBlock(nn.Module):
224
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
225
+ super().__init__()
226
+
227
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
228
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
229
+
230
+ self.skip = ConvLayer(
231
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
232
+ )
233
+
234
+ def forward(self, input):
235
+ out = self.conv1(input)
236
+ out = self.conv2(out)
237
+
238
+ skip = self.skip(input)
239
+ out = (out + skip) / math.sqrt(2)
240
+
241
+ return out
242
+
243
+
244
+ class Discriminator(nn.Module):
245
+ def __init__(self, size, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
246
+ super().__init__()
247
+
248
+ self.size = size
249
+
250
+ channels = {
251
+ 4: 512,
252
+ 8: 512,
253
+ 16: 512,
254
+ 32: 512,
255
+ 64: 256 * channel_multiplier,
256
+ 128: 128 * channel_multiplier,
257
+ 256: 64 * channel_multiplier,
258
+ 512: 32 * channel_multiplier,
259
+ 1024: 16 * channel_multiplier,
260
+ }
261
+
262
+ convs = [ConvLayer(3, channels[size], 1)]
263
+ log_size = int(math.log(size, 2))
264
+ in_channel = channels[size]
265
+
266
+ for i in range(log_size, 2, -1):
267
+ out_channel = channels[2 ** (i - 1)]
268
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
269
+ in_channel = out_channel
270
+
271
+ self.convs = nn.Sequential(*convs)
272
+
273
+ self.stddev_group = 4
274
+ self.stddev_feat = 1
275
+
276
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
277
+ self.final_linear = nn.Sequential(
278
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
279
+ EqualLinear(channels[4], 1),
280
+ )
281
+
282
+ def forward(self, input):
283
+ out = self.convs(input)
284
+ batch, channel, height, width = out.shape
285
+
286
+ group = min(batch, self.stddev_group)
287
+ stddev = out.view(
288
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
289
+ )
290
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
291
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
292
+ stddev = stddev.repeat(group, 1, height, width)
293
+ out = torch.cat([out, stddev], 1)
294
+
295
+ out = self.final_conv(out)
296
+
297
+ out = out.view(batch, -1)
298
+ out = self.final_linear(out)
299
+
300
+ return out
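A usage sketch for the StyleGAN2-style discriminator above (illustrative only, assuming square RGB inputs whose side length matches size and a batch large enough for the minibatch-stddev statistic):

import torch
from visualizr.networks.discriminator import Discriminator

disc = Discriminator(size=256)      # ResBlocks downsample 256x256 -> 4x4
imgs = torch.randn(4, 3, 256, 256)  # real or generated frames
logits = disc(imgs)                 # one realism score per image
print(logits.shape)                 # expected: torch.Size([4, 1])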
src/visualizr/networks/encoder.py ADDED
@@ -0,0 +1,432 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from visualizr import logger
8
+
9
+
10
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
11
+ return F.leaky_relu(input + bias, negative_slope) * scale
12
+
13
+
14
+ class FusedLeakyReLU(nn.Module):
15
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
16
+ super().__init__()
17
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
18
+ self.negative_slope = negative_slope
19
+ self.scale = scale
20
+
21
+ def forward(self, input):
22
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
23
+ return out
24
+
25
+
26
+ def upfirdn2d_native(
27
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
28
+ ):
29
+ _, minor, in_h, in_w = input.shape
30
+ kernel_h, kernel_w = kernel.shape
31
+
32
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
33
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
34
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
35
+
36
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
37
+ out = out[
38
+ :,
39
+ :,
40
+ max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
41
+ max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
42
+ ]
43
+
44
+ out = out.reshape(
45
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
46
+ )
47
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
48
+ out = F.conv2d(out, w)
49
+ out = out.reshape(
50
+ -1,
51
+ minor,
52
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
53
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
54
+ )
55
+
56
+ return out[:, :, ::down_y, ::down_x]
57
+
58
+
59
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
60
+ return upfirdn2d_native(
61
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
62
+ )
63
+
64
+
65
+ def make_kernel(k):
66
+ k = torch.tensor(k, dtype=torch.float32)
67
+
68
+ if k.ndim == 1:
69
+ k = k[None, :] * k[:, None]
70
+
71
+ k /= k.sum()
72
+
73
+ return k
74
+
75
+
76
+ class Blur(nn.Module):
77
+ def __init__(self, kernel, pad, upsample_factor=1):
78
+ super().__init__()
79
+
80
+ kernel = make_kernel(kernel)
81
+
82
+ if upsample_factor > 1:
83
+ kernel = kernel * (upsample_factor**2)
84
+
85
+ self.register_buffer("kernel", kernel)
86
+
87
+ self.pad = pad
88
+
89
+ def forward(self, input):
90
+ return upfirdn2d(input, self.kernel, pad=self.pad)
91
+
92
+
93
+ class ScaledLeakyReLU(nn.Module):
94
+ def __init__(self, negative_slope=0.2):
95
+ super().__init__()
96
+
97
+ self.negative_slope = negative_slope
98
+
99
+ def forward(self, input):
100
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
101
+
102
+
103
+ class EqualConv2d(nn.Module):
104
+ def __init__(
105
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
106
+ ):
107
+ super().__init__()
108
+
109
+ self.weight = nn.Parameter(
110
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
111
+ )
112
+ self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
113
+
114
+ self.stride = stride
115
+ self.padding = padding
116
+
117
+ if bias:
118
+ self.bias = nn.Parameter(torch.zeros(out_channel))
119
+ else:
120
+ self.bias = None
121
+
122
+ def forward(self, input):
123
+ return F.conv2d(
124
+ input,
125
+ self.weight * self.scale,
126
+ bias=self.bias,
127
+ stride=self.stride,
128
+ padding=self.padding,
129
+ )
130
+
131
+ def __repr__(self):
132
+ return (
133
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
134
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
135
+ )
136
+
137
+
138
+ class EqualLinear(nn.Module):
139
+ def __init__(
140
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
141
+ ):
142
+ super().__init__()
143
+
144
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
145
+
146
+ if bias:
147
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
148
+ else:
149
+ self.bias = None
150
+
151
+ self.activation = activation
152
+
153
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
154
+ self.lr_mul = lr_mul
155
+
156
+ def forward(self, input):
157
+ if self.activation:
158
+ out = F.linear(input, self.weight * self.scale)
159
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
160
+ else:
161
+ out = F.linear(
162
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
163
+ )
164
+
165
+ return out
166
+
167
+ def __repr__(self):
168
+ return (
169
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
170
+ )
171
+
172
+
173
+ class ConvLayer(nn.Sequential):
174
+ def __init__(
175
+ self,
176
+ in_channel,
177
+ out_channel,
178
+ kernel_size,
179
+ downsample=False,
180
+ blur_kernel=[1, 3, 3, 1],
181
+ bias=True,
182
+ activate=True,
183
+ ):
184
+ layers = []
185
+
186
+ if downsample:
187
+ factor = 2
188
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
189
+ pad0 = (p + 1) // 2
190
+ pad1 = p // 2
191
+
192
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
193
+
194
+ stride = 2
195
+ self.padding = 0
196
+
197
+ else:
198
+ stride = 1
199
+ self.padding = kernel_size // 2
200
+
201
+ layers.append(
202
+ EqualConv2d(
203
+ in_channel,
204
+ out_channel,
205
+ kernel_size,
206
+ padding=self.padding,
207
+ stride=stride,
208
+ bias=bias and not activate,
209
+ )
210
+ )
211
+
212
+ if activate:
213
+ if bias:
214
+ layers.append(FusedLeakyReLU(out_channel))
215
+ else:
216
+ layers.append(ScaledLeakyReLU(0.2))
217
+
218
+ super().__init__(*layers)
219
+
220
+
221
+ class ResBlock(nn.Module):
222
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
223
+ super().__init__()
224
+
225
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
226
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
227
+
228
+ self.skip = ConvLayer(
229
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
230
+ )
231
+
232
+ def forward(self, input):
233
+ out = self.conv1(input)
234
+ out = self.conv2(out)
235
+
236
+ skip = self.skip(input)
237
+ out = (out + skip) / math.sqrt(2)
238
+
239
+ return out
240
+
241
+
242
+ class WeightedSumLayer(nn.Module):
243
+ def __init__(self, num_tensors=8):
244
+ super(WeightedSumLayer, self).__init__()
245
+
246
+ self.weights = nn.Parameter(torch.randn(num_tensors))
247
+
248
+ def forward(self, tensor_list):
249
+ weights = torch.softmax(self.weights, dim=0)
250
+ weighted_sum = torch.zeros_like(tensor_list[0])
251
+ for tensor, weight in zip(tensor_list, weights):
252
+ weighted_sum += tensor * weight
253
+
254
+ return weighted_sum
255
+
256
+
257
+ class EncoderApp(nn.Module):
258
+ def __init__(self, size, w_dim=512, fusion_type=""):
259
+ super(EncoderApp, self).__init__()
260
+
261
+ channels = {
262
+ 4: 512,
263
+ 8: 512,
264
+ 16: 512,
265
+ 32: 512,
266
+ 64: 256,
267
+ 128: 128,
268
+ 256: 64,
269
+ 512: 32,
270
+ 1024: 16,
271
+ }
272
+
273
+ self.w_dim = w_dim
274
+ log_size = int(math.log(size, 2))
275
+
276
+ self.convs = nn.ModuleList()
277
+ self.convs.append(ConvLayer(3, channels[size], 1))
278
+
279
+ in_channel = channels[size]
280
+ for i in range(log_size, 2, -1):
281
+ out_channel = channels[2 ** (i - 1)]
282
+ self.convs.append(ResBlock(in_channel, out_channel))
283
+ in_channel = out_channel
284
+
285
+ self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
286
+
287
+ self.fusion_type = fusion_type
288
+ assert self.fusion_type == "weighted_sum"
289
+ if self.fusion_type == "weighted_sum":
290
+ logger.info("HAL layer is enabled!")
291
+ self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
292
+ self.fc1 = EqualLinear(64, 512)
293
+ self.fc2 = EqualLinear(128, 512)
294
+ self.fc3 = EqualLinear(256, 512)
295
+ self.ws = WeightedSumLayer()
296
+
297
+ def forward(self, x):
298
+ res = []
299
+ h = x
300
+ pooled_h_lists = []
301
+ for i, conv in enumerate(self.convs):
302
+ h = conv(h)
303
+ if self.fusion_type == "weighted_sum":
304
+ pooled_h = self.adaptive_pool(h).view(x.size(0), -1)
305
+ if i == 0:
306
+ pooled_h_lists.append(self.fc1(pooled_h))
307
+ elif i == 1:
308
+ pooled_h_lists.append(self.fc2(pooled_h))
309
+ elif i == 2:
310
+ pooled_h_lists.append(self.fc3(pooled_h))
311
+ else:
312
+ pooled_h_lists.append(pooled_h)
313
+ res.append(h)
314
+
315
+ if self.fusion_type == "weighted_sum":
316
+ last_layer = self.ws(pooled_h_lists)
317
+ else:
318
+ last_layer = res[-1].squeeze(-1).squeeze(-1)
319
+ layer_features = res[::-1][2:]
320
+
321
+ return last_layer, layer_features
322
+
323
+
324
+ class DecouplingModel(nn.Module):
325
+ def __init__(self, input_dim, hidden_dim, output_dim):
326
+ super(DecouplingModel, self).__init__()
327
+
328
+ # identity_net is called the identity encoder in the paper
329
+ self.identity_net = nn.Sequential(
330
+ nn.Linear(input_dim, hidden_dim),
331
+ nn.ReLU(),
332
+ nn.Linear(hidden_dim, output_dim),
333
+ )
334
+
335
+ self.identity_net_density = nn.Sequential(
336
+ nn.Linear(input_dim, hidden_dim),
337
+ nn.ReLU(),
338
+ nn.Linear(hidden_dim, output_dim),
339
+ )
340
+
341
+ # identity_excluded_net is called motion encoder in the paper
342
+ self.identity_excluded_net = nn.Sequential(
343
+ nn.Linear(input_dim, hidden_dim),
344
+ nn.ReLU(),
345
+ nn.Linear(hidden_dim, output_dim),
346
+ )
347
+
348
+ def forward(self, x):
349
+ id_, id_rm = self.identity_net(x), self.identity_excluded_net(x)
350
+ id_density = self.identity_net_density(id_)
351
+ return id_, id_rm, id_density
352
+
353
+
354
+ class Encoder(nn.Module):
355
+ def __init__(self, size, dim=512, dim_motion=20, weighted_sum=False):
356
+ super(Encoder, self).__init__()
357
+
358
+ # image encoder
359
+ self.net_app = EncoderApp(size, dim, weighted_sum)
360
+
361
+ # decoupling network
362
+ self.net_decouping = DecouplingModel(dim, dim, dim)
363
+
364
+ # part of the motion encoder
365
+ fc = [EqualLinear(dim, dim)]
366
+ for i in range(3):
367
+ fc.append(EqualLinear(dim, dim))
368
+
369
+ fc.append(EqualLinear(dim, dim_motion))
370
+ self.fc = nn.Sequential(*fc)
371
+
372
+ def enc_app(self, x):
373
+ h_source = self.net_app(x)
374
+
375
+ return h_source
376
+
377
+ def enc_motion(self, x):
378
+ h, _ = self.net_app(x)
379
+ h_motion = self.fc(h)
380
+
381
+ return h_motion
382
+
383
+ def encode_image_obj(self, image_obj):
384
+ feat, _ = self.net_app(image_obj)
385
+ id_emb, idrm_emb, id_density_emb = self.net_decouping(feat)
386
+ return id_emb, idrm_emb, id_density_emb
387
+
388
+ def forward(self, input_source, input_target, input_face, input_aug):
389
+ if input_target is not None:
390
+ h_source, feats = self.net_app(input_source)
391
+ h_target, _ = self.net_app(input_target)
392
+ h_face, _ = self.net_app(input_face)
393
+ h_aug, _ = self.net_app(input_aug)
394
+
395
+ h_source_id_emb, h_source_idrm_emb, h_source_id_density_emb = (
396
+ self.net_decouping(h_source)
397
+ )
398
+ h_target_id_emb, h_target_idrm_emb, h_target_id_density_emb = (
399
+ self.net_decouping(h_target)
400
+ )
401
+ h_face_id_emb, h_face_idrm_emb, h_face_id_density_emb = self.net_decouping(
402
+ h_face
403
+ )
404
+ h_aug_id_emb, h_aug_idrm_emb, h_aug_id_density_emb = self.net_decouping(
405
+ h_aug
406
+ )
407
+
408
+ h_target_motion_target = self.fc(h_target_idrm_emb)
409
+ h_another_face_target = self.fc(h_face_idrm_emb)
410
+
411
+ else:
412
+ h_source, feats = self.net_app(input_source)
413
+
414
+ return {
415
+ "h_source": h_source,
416
+ "h_motion": h_target_motion_target,
417
+ "feats": feats,
418
+ "h_another_face_target": h_another_face_target,
419
+ "h_face": h_face,
420
+ "h_source_id_emb": h_source_id_emb,
421
+ "h_source_idrm_emb": h_source_idrm_emb,
422
+ "h_source_id_density_emb": h_source_id_density_emb,
423
+ "h_target_id_emb": h_target_id_emb,
424
+ "h_target_idrm_emb": h_target_idrm_emb,
425
+ "h_target_id_density_emb": h_target_id_density_emb,
426
+ "h_face_id_emb": h_face_id_emb,
427
+ "h_face_idrm_emb": h_face_idrm_emb,
428
+ "h_face_id_density_emb": h_face_id_density_emb,
429
+ "h_aug_id_emb": h_aug_id_emb,
430
+ "h_aug_idrm_emb": h_aug_idrm_emb,
431
+ "h_aug_id_density_emb": h_aug_id_density_emb,
432
+ }
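A sketch of running the appearance encoder on random input (illustrative only). As the assert above enforces, only the "weighted_sum" fusion path (the HAL layer) is supported, so it is passed explicitly:

import torch
from visualizr.networks.encoder import EncoderApp

enc = EncoderApp(size=256, w_dim=512, fusion_type="weighted_sum")
img = torch.randn(2, 3, 256, 256)
last_layer, layer_features = enc(img)         # global code + multi-resolution skip features
print(last_layer.shape, len(layer_features))  # expected: torch.Size([2, 512]) and 6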
src/visualizr/networks/generator.py ADDED
@@ -0,0 +1,37 @@
1
+ from torch import nn
2
+
3
+ from visualizr.networks.encoder import Encoder
4
+ from visualizr.networks.styledecoder import Synthesis
5
+
6
+
7
+ class Generator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ size,
11
+ style_dim=512,
12
+ motion_dim=20,
13
+ channel_multiplier=1,
14
+ blur_kernel=[1, 3, 3, 1],
15
+ ):
16
+ super(Generator, self).__init__()
17
+
18
+ # encoder
19
+ self.enc = Encoder(size, style_dim, motion_dim)
20
+ self.dec = Synthesis(
21
+ size, style_dim, motion_dim, blur_kernel, channel_multiplier
22
+ )
23
+
24
+ def get_direction(self):
25
+ return self.dec.direction(None)
26
+
27
+ def synthesis(self, wa, alpha, feat):
28
+ img = self.dec(wa, alpha, feat)
29
+
30
+ return img
31
+
32
+ def forward(self, img_source, img_drive, h_start=None):
33
+ wa, alpha, feats = self.enc(img_source, img_drive, h_start)
34
+ # import pdb;pdb.set_trace()
35
+ img_recon = self.dec(wa, alpha, feats)
36
+
37
+ return img_recon
src/visualizr/networks/styledecoder.py ADDED
@@ -0,0 +1,618 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
10
+ return F.leaky_relu(input + bias, negative_slope) * scale
11
+
12
+
13
+ class FusedLeakyReLU(nn.Module):
14
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
15
+ super().__init__()
16
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
17
+ self.negative_slope = negative_slope
18
+ self.scale = scale
19
+
20
+ def forward(self, input):
21
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
22
+ return out
23
+
24
+
25
+ def upfirdn2d_native(
26
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
27
+ ):
28
+ _, minor, in_h, in_w = input.shape
29
+ kernel_h, kernel_w = kernel.shape
30
+
31
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
32
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
33
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
34
+
35
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
36
+ out = out[
37
+ :,
38
+ :,
39
+ max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
40
+ max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
41
+ ]
42
+
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ return out[:, :, ::down_y, ::down_x]
55
+
56
+
57
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
58
+ return upfirdn2d_native(
59
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
60
+ )
61
+
62
+
63
+ class PixelNorm(nn.Module):
64
+ def __init__(self):
65
+ super().__init__()
66
+
67
+ def forward(self, input):
68
+ return input * torch.rsqrt(torch.mean(input**2, dim=1, keepdim=True) + 1e-8)
69
+
70
+
71
+ class MotionPixelNorm(nn.Module):
72
+ def __init__(self):
73
+ super().__init__()
74
+
75
+ def forward(self, input):
76
+ return input * torch.rsqrt(torch.mean(input**2, dim=2, keepdim=True) + 1e-8)
77
+
78
+
79
+ def make_kernel(k):
80
+ k = torch.tensor(k, dtype=torch.float32)
81
+
82
+ if k.ndim == 1:
83
+ k = k[None, :] * k[:, None]
84
+
85
+ k /= k.sum()
86
+
87
+ return k
88
+
89
+
90
+ class Upsample(nn.Module):
91
+ def __init__(self, kernel, factor=2):
92
+ super().__init__()
93
+
94
+ self.factor = factor
95
+ kernel = make_kernel(kernel) * (factor**2)
96
+ self.register_buffer("kernel", kernel)
97
+
98
+ p = kernel.shape[0] - factor
99
+
100
+ pad0 = (p + 1) // 2 + factor - 1
101
+ pad1 = p // 2
102
+
103
+ self.pad = (pad0, pad1)
104
+
105
+ def forward(self, input):
106
+ return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
107
+
108
+
109
+ class Downsample(nn.Module):
110
+ def __init__(self, kernel, factor=2):
111
+ super().__init__()
112
+
113
+ self.factor = factor
114
+ kernel = make_kernel(kernel)
115
+ self.register_buffer("kernel", kernel)
116
+
117
+ p = kernel.shape[0] - factor
118
+
119
+ pad0 = (p + 1) // 2
120
+ pad1 = p // 2
121
+
122
+ self.pad = (pad0, pad1)
123
+
124
+ def forward(self, input):
125
+ return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
126
+
127
+
128
+ class Blur(nn.Module):
129
+ def __init__(self, kernel, pad, upsample_factor=1):
130
+ super().__init__()
131
+
132
+ kernel = make_kernel(kernel)
133
+
134
+ if upsample_factor > 1:
135
+ kernel = kernel * (upsample_factor**2)
136
+
137
+ self.register_buffer("kernel", kernel)
138
+
139
+ self.pad = pad
140
+
141
+ def forward(self, input):
142
+ return upfirdn2d(input, self.kernel, pad=self.pad)
143
+
144
+
145
+ class EqualConv2d(nn.Module):
146
+ def __init__(
147
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
148
+ ):
149
+ super().__init__()
150
+
151
+ self.weight = nn.Parameter(
152
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
153
+ )
154
+ self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
155
+
156
+ self.stride = stride
157
+ self.padding = padding
158
+
159
+ if bias:
160
+ self.bias = nn.Parameter(torch.zeros(out_channel))
161
+ else:
162
+ self.bias = None
163
+
164
+ def forward(self, input):
165
+ return F.conv2d(
166
+ input,
167
+ self.weight * self.scale,
168
+ bias=self.bias,
169
+ stride=self.stride,
170
+ padding=self.padding,
171
+ )
172
+
173
+ def __repr__(self):
174
+ return (
175
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
176
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
177
+ )
178
+
179
+
180
+ class EqualLinear(nn.Module):
181
+ def __init__(
182
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
183
+ ):
184
+ super().__init__()
185
+
186
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
187
+
188
+ if bias:
189
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
190
+ else:
191
+ self.bias = None
192
+
193
+ self.activation = activation
194
+
195
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
196
+ self.lr_mul = lr_mul
197
+
198
+ def forward(self, input):
199
+ if self.activation:
200
+ out = F.linear(input, self.weight * self.scale)
201
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
202
+ else:
203
+ out = F.linear(
204
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
205
+ )
206
+
207
+ return out
208
+
209
+ def __repr__(self):
210
+ return (
211
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
212
+ )
213
+
214
+
215
+ class ScaledLeakyReLU(nn.Module):
216
+ def __init__(self, negative_slope=0.2):
217
+ super().__init__()
218
+
219
+ self.negative_slope = negative_slope
220
+
221
+ def forward(self, input):
222
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
223
+
224
+
225
+ class ModulatedConv2d(nn.Module):
226
+ def __init__(
227
+ self,
228
+ in_channel,
229
+ out_channel,
230
+ kernel_size,
231
+ style_dim,
232
+ demodulate=True,
233
+ upsample=False,
234
+ downsample=False,
235
+ blur_kernel=[1, 3, 3, 1],
236
+ ):
237
+ super().__init__()
238
+
239
+ self.eps = 1e-8
240
+ self.kernel_size = kernel_size
241
+ self.in_channel = in_channel
242
+ self.out_channel = out_channel
243
+ self.upsample = upsample
244
+ self.downsample = downsample
245
+
246
+ if upsample:
247
+ factor = 2
248
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
249
+ pad0 = (p + 1) // 2 + factor - 1
250
+ pad1 = p // 2 + 1
251
+
252
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
253
+
254
+ if downsample:
255
+ factor = 2
256
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
257
+ pad0 = (p + 1) // 2
258
+ pad1 = p // 2
259
+
260
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
261
+
262
+ fan_in = in_channel * kernel_size**2
263
+ self.scale = 1 / math.sqrt(fan_in)
264
+ self.padding = kernel_size // 2
265
+
266
+ self.weight = nn.Parameter(
267
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
268
+ )
269
+
270
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
271
+ self.demodulate = demodulate
272
+
273
+ def __repr__(self):
274
+ return (
275
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
276
+ f"upsample={self.upsample}, downsample={self.downsample})"
277
+ )
278
+
279
+ def forward(self, input, style):
280
+ batch, in_channel, height, width = input.shape
281
+
282
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
283
+ weight = self.scale * self.weight * style
284
+
285
+ if self.demodulate:
286
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
287
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
288
+
289
+ weight = weight.view(
290
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
291
+ )
292
+
293
+ if self.upsample:
294
+ input = input.view(1, batch * in_channel, height, width)
295
+ weight = weight.view(
296
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
297
+ )
298
+ weight = weight.transpose(1, 2).reshape(
299
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
300
+ )
301
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
302
+ _, _, height, width = out.shape
303
+ out = out.view(batch, self.out_channel, height, width)
304
+ out = self.blur(out)
305
+ elif self.downsample:
306
+ input = self.blur(input)
307
+ _, _, height, width = input.shape
308
+ input = input.view(1, batch * in_channel, height, width)
309
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
310
+ _, _, height, width = out.shape
311
+ out = out.view(batch, self.out_channel, height, width)
312
+ else:
313
+ input = input.view(1, batch * in_channel, height, width)
314
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
315
+ _, _, height, width = out.shape
316
+ out = out.view(batch, self.out_channel, height, width)
317
+
318
+ return out
319
+
320
+
321
+ class NoiseInjection(nn.Module):
322
+ def __init__(self):
323
+ super().__init__()
324
+
325
+ self.weight = nn.Parameter(torch.zeros(1))
326
+
327
+ def forward(self, image, noise=None):
328
+ if noise is None:
329
+ return image
330
+ else:
331
+ return image + self.weight * noise
332
+
333
+
334
+ class ConstantInput(nn.Module):
335
+ def __init__(self, channel, size=4):
336
+ super().__init__()
337
+
338
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
339
+
340
+ def forward(self, input):
341
+ batch = input.shape[0]
342
+ out = self.input.repeat(batch, 1, 1, 1)
343
+
344
+ return out
345
+
346
+
347
+ class StyledConv(nn.Module):
348
+ def __init__(
349
+ self,
350
+ in_channel,
351
+ out_channel,
352
+ kernel_size,
353
+ style_dim,
354
+ upsample=False,
355
+ blur_kernel=[1, 3, 3, 1],
356
+ demodulate=True,
357
+ ):
358
+ super().__init__()
359
+
360
+ self.conv = ModulatedConv2d(
361
+ in_channel,
362
+ out_channel,
363
+ kernel_size,
364
+ style_dim,
365
+ upsample=upsample,
366
+ blur_kernel=blur_kernel,
367
+ demodulate=demodulate,
368
+ )
369
+
370
+ self.noise = NoiseInjection()
371
+ self.activate = FusedLeakyReLU(out_channel)
372
+
373
+ def forward(self, input, style, noise=None):
374
+ out = self.conv(input, style)
375
+ out = self.noise(out, noise=noise)
376
+ out = self.activate(out)
377
+
378
+ return out
379
+
380
+
381
+ class ConvLayer(nn.Sequential):
382
+ def __init__(
383
+ self,
384
+ in_channel,
385
+ out_channel,
386
+ kernel_size,
387
+ downsample=False,
388
+ blur_kernel=[1, 3, 3, 1],
389
+ bias=True,
390
+ activate=True,
391
+ ):
392
+ layers = []
393
+
394
+ if downsample:
395
+ factor = 2
396
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
397
+ pad0 = (p + 1) // 2
398
+ pad1 = p // 2
399
+
400
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
401
+
402
+ stride = 2
403
+ self.padding = 0
404
+
405
+ else:
406
+ stride = 1
407
+ self.padding = kernel_size // 2
408
+
409
+ layers.append(
410
+ EqualConv2d(
411
+ in_channel,
412
+ out_channel,
413
+ kernel_size,
414
+ padding=self.padding,
415
+ stride=stride,
416
+ bias=bias and not activate,
417
+ )
418
+ )
419
+
420
+ if activate:
421
+ if bias:
422
+ layers.append(FusedLeakyReLU(out_channel))
423
+ else:
424
+ layers.append(ScaledLeakyReLU(0.2))
425
+
426
+ super().__init__(*layers)
427
+
428
+
429
+ class ToRGB(nn.Module):
430
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
431
+ super().__init__()
432
+
433
+ if upsample:
434
+ self.upsample = Upsample(blur_kernel)
435
+
436
+ self.conv = ConvLayer(in_channel, 3, 1)
437
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
438
+
439
+ def forward(self, input, skip=None):
440
+ out = self.conv(input)
441
+ out = out + self.bias
442
+
443
+ if skip is not None:
444
+ skip = self.upsample(skip)
445
+ out = out + skip
446
+
447
+ return out
448
+
449
+
450
+ class ToFlow(nn.Module):
451
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
452
+ super().__init__()
453
+
454
+ if upsample:
455
+ self.upsample = Upsample(blur_kernel)
456
+
457
+ self.style_dim = style_dim
458
+ self.in_channel = in_channel
459
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
460
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
461
+
462
+ def forward(
463
+ self, input, style, feat, skip=None
464
+ ): # input is the feature from the previous layer, style is the 512-dim condition, feat is the skip-layer feature from the UNet
465
+ out = self.conv(input, style)
466
+ out = out + self.bias
467
+
468
+ # warping
469
+ xs = np.linspace(-1, 1, input.size(2))
470
+
471
+ xs = np.meshgrid(xs, xs)
472
+ xs = np.stack(xs, 2)
473
+
474
+ xs = (
475
+ torch.tensor(xs, requires_grad=False)
476
+ .float()
477
+ .unsqueeze(0)
478
+ .repeat(input.size(0), 1, 1, 1)
479
+ .to(input.device)
480
+ )
481
+ # import pdb;pdb.set_trace()
482
+ if skip is not None:
483
+ skip = self.upsample(skip)
484
+ out = out + skip
485
+
486
+ sampler = torch.tanh(out[:, 0:2, :, :])
487
+ mask = torch.sigmoid(out[:, 2:3, :, :])
488
+ flow = sampler.permute(0, 2, 3, 1) + xs # xs serves as the base grid of sampling locations
489
+
490
+ feat_warp = F.grid_sample(feat, flow) * mask
491
+ # import pdb;pdb.set_trace()
492
+ return feat_warp, feat_warp + input * (1.0 - mask), out
493
+
494
+
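ToFlow predicts a 2-channel offset field and a 1-channel blending mask; the offsets are added to an identity grid in [-1, 1] and the encoder feature is warped with F.grid_sample. A standalone sketch of that grid construction (shapes are arbitrary and the flow is zero, so the warp reproduces the input):

import numpy as np
import torch
import torch.nn.functional as F

b, c, h = 1, 4, 8
feat = torch.randn(b, c, h, h)
xs = np.linspace(-1, 1, h)
grid = torch.tensor(np.stack(np.meshgrid(xs, xs), 2)).float()   # h x h x 2 in [-1, 1]
grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)
flow = torch.zeros(b, 2, h, h)                                   # network output (zero here)
sample_grid = flow.permute(0, 2, 3, 1) + grid                    # absolute sampling locations
# align_corners=True makes linspace(-1, 1, h) coincide with pixel centers;
# the module itself relies on grid_sample's library default.
warped = F.grid_sample(feat, sample_grid, align_corners=True)
print(torch.allclose(warped, feat, atol=1e-5))                   # True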
495
+ class Direction(nn.Module):
496
+ def __init__(self, motion_dim):
497
+ super(Direction, self).__init__()
498
+
499
+ self.weight = nn.Parameter(torch.randn(512, motion_dim))
500
+
501
+ def forward(self, input):
502
+ # input: (bs*t) x 512
503
+
504
+ weight = self.weight + 1e-8
505
+ Q, R = torch.linalg.qr(weight) # orthonormal basis of motion directions [n1, n2, n3, n4]
506
+
507
+ if input is None:
508
+ return Q
509
+ else:
510
+ input_diag = torch.diag_embed(input) # alpha, diagonal matrix
511
+ out = torch.matmul(input_diag, Q.T)
512
+ out = torch.sum(out, dim=1)
513
+ return out
514
+
515
+
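Direction implements the Linear Motion Decomposition from LIA: QR yields an orthonormal basis of motion directions, and the motion code alpha recombines them, which is the same linear combination as alpha @ Q.T. A small standalone self-check (shapes are illustrative):

import torch

motion_dim, style_dim, bs = 20, 512, 4
weight = torch.randn(style_dim, motion_dim)
Q, _ = torch.linalg.qr(weight)                  # style_dim x motion_dim, orthonormal columns
alpha = torch.randn(bs, motion_dim)             # per-sample motion code

out = torch.matmul(torch.diag_embed(alpha), Q.T).sum(dim=1)          # as in Direction.forward
print(torch.allclose(out, alpha @ Q.T, atol=1e-5))                   # True: same combination
print(torch.allclose(Q.T @ Q, torch.eye(motion_dim), atol=1e-5))     # True: orthonormal basis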
516
+ class Synthesis(nn.Module):
517
+ def __init__(
518
+ self,
519
+ size,
520
+ style_dim,
521
+ motion_dim,
522
+ blur_kernel=[1, 3, 3, 1],
523
+ channel_multiplier=1,
524
+ ):
525
+ super(Synthesis, self).__init__()
526
+
527
+ self.size = size
528
+ self.style_dim = style_dim
529
+ self.motion_dim = motion_dim
530
+
531
+ self.direction = Direction(
532
+ motion_dim
533
+ ) # Linear Motion Decomposition (LMD) from LIA
534
+
535
+ self.channels = {
536
+ 4: 512,
537
+ 8: 512,
538
+ 16: 512,
539
+ 32: 512,
540
+ 64: 256 * channel_multiplier,
541
+ 128: 128 * channel_multiplier,
542
+ 256: 64 * channel_multiplier,
543
+ 512: 32 * channel_multiplier,
544
+ 1024: 16 * channel_multiplier,
545
+ }
546
+
547
+ self.input = ConstantInput(self.channels[4])
548
+ self.conv1 = StyledConv(
549
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
550
+ )
551
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
552
+
553
+ self.log_size = int(math.log(size, 2))
554
+ self.num_layers = (self.log_size - 2) * 2 + 1
555
+
556
+ self.convs = nn.ModuleList()
557
+ self.upsamples = nn.ModuleList()
558
+ self.to_rgbs = nn.ModuleList()
559
+ self.to_flows = nn.ModuleList()
560
+
561
+ in_channel = self.channels[4]
562
+
563
+ for i in range(3, self.log_size + 1):
564
+ out_channel = self.channels[2**i]
565
+
566
+ self.convs.append(
567
+ StyledConv(
568
+ in_channel,
569
+ out_channel,
570
+ 3,
571
+ style_dim,
572
+ upsample=True,
573
+ blur_kernel=blur_kernel,
574
+ )
575
+ )
576
+ self.convs.append(
577
+ StyledConv(
578
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
579
+ )
580
+ )
581
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
582
+
583
+ self.to_flows.append(ToFlow(out_channel, style_dim))
584
+
585
+ in_channel = out_channel
586
+
587
+ self.n_latent = self.log_size * 2 - 2
588
+
589
+ def forward(self, source_before_decoupling, target_motion, feats):
590
+ skip_flow, skip = None, None # filled in on the first (8x8) decoder stage
591
+ directions = self.direction(target_motion)
592
+ latent = source_before_decoupling + directions # wa + directions
593
+
594
+ inject_index = self.n_latent
595
+ latent = latent.unsqueeze(1).repeat(1, inject_index, 1)
596
+
597
+ out = self.input(latent)
598
+ out = self.conv1(out, latent[:, 0])
599
+
600
+ i = 1
601
+ for conv1, conv2, to_rgb, to_flow, feat in zip(
602
+ self.convs[::2], self.convs[1::2], self.to_rgbs, self.to_flows, feats
603
+ ):
604
+ out = conv1(out, latent[:, i])
605
+ out = conv2(out, latent[:, i + 1])
606
+ if out.size(2) == 8:
607
+ out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat)
608
+ skip = to_rgb(out_warp)
609
+ else:
610
+ out_warp, out, skip_flow = to_flow(
611
+ out, latent[:, i + 2], feat, skip_flow
612
+ )
613
+ skip = to_rgb(out_warp, skip)
614
+ i += 2
615
+
616
+ img = skip
617
+
618
+ return img
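A hypothetical shape-level smoke test of Synthesis, assuming the custom ops defined earlier in this file (EqualLinear, EqualConv2d, Blur, Upsample, FusedLeakyReLU, ScaledLeakyReLU) are importable; the encoder skip features are faked with random tensors of the channel counts the decoder expects, so this only checks wiring, not real usage:

import torch

size, style_dim, motion_dim = 256, 512, 20
net = Synthesis(size, style_dim, motion_dim)
wa = torch.randn(1, style_dim)                 # appearance code from the LIA encoder
alpha = torch.randn(1, motion_dim)             # motion code
feats = [
    torch.randn(1, net.channels[2 ** i], 2 ** i, 2 ** i)   # fake UNet skip features
    for i in range(3, net.log_size + 1)
]
img = net(wa, alpha, feats)
print(img.shape)                               # expected: torch.Size([1, 3, 256, 256])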
src/visualizr/networks/utils.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class AntiAliasInterpolation2d(nn.Module):
7
+ """
8
+ Band-limited downsampling, for better preservation of the input signal.
9
+ """
10
+
11
+ def __init__(self, channels, scale):
12
+ super(AntiAliasInterpolation2d, self).__init__()
13
+ sigma = (1 / scale - 1) / 2
14
+ kernel_size = 2 * round(sigma * 4) + 1
15
+ self.ka = kernel_size // 2
16
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
17
+
18
+ kernel_size = [kernel_size, kernel_size]
19
+ sigma = [sigma, sigma]
20
+ # The gaussian kernel is the product of the gaussian function of each dimension.
21
+ kernel = 1
22
+ meshgrids = torch.meshgrid(
23
+ [torch.arange(size, dtype=torch.float32) for size in kernel_size]
24
+ )
25
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
26
+ mean = (size - 1) / 2
27
+ kernel *= torch.exp(-((mgrid - mean) ** 2) / (2 * std**2))
28
+
29
+ # Make sure sum of values in gaussian kernel equals 1.
30
+ kernel /= torch.sum(kernel)
31
+ # Reshape to depthwise convolutional weight
32
+ kernel = kernel.view(1, 1, *kernel.size())
33
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
34
+
35
+ self.register_buffer("weight", kernel)
36
+ self.groups = channels
37
+ self.scale = scale
38
+ inv_scale = 1 / scale
39
+ self.int_inv_scale = int(inv_scale)
40
+
41
+ def forward(self, input):
42
+ if self.scale == 1.0:
43
+ return input
44
+
45
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
46
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
47
+ out = out[:, :, :: self.int_inv_scale, :: self.int_inv_scale]
48
+
49
+ return out
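A minimal usage sketch of AntiAliasInterpolation2d, assuming only the class above: Gaussian-blur and then subsample a batch of RGB tensors by a factor of 4.

import torch

down = AntiAliasInterpolation2d(channels=3, scale=0.25)
x = torch.randn(2, 3, 256, 256)
y = down(x)
print(y.shape)   # torch.Size([2, 3, 64, 64])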
src/visualizr/renderer.py ADDED
@@ -0,0 +1,35 @@
1
+ from visualizr.choices import TrainMode
2
+ from visualizr.config import TrainConfig
3
+
4
+
5
+ def render_condition(
6
+ conf: TrainConfig,
7
+ model,
8
+ sampler,
9
+ start,
10
+ motion_direction_start,
11
+ audio_driven,
12
+ face_location,
13
+ face_scale,
14
+ yaw_pitch_roll,
15
+ noisy_t,
16
+ control_flag,
17
+ ):
18
+ if conf.train_mode == TrainMode.diffusion:
19
+ assert conf.model_type.has_autoenc()
20
+
21
+ return sampler.sample(
22
+ model=model,
23
+ noise=noisy_t,
24
+ model_kwargs={
25
+ "motion_direction_start": motion_direction_start,
26
+ "yaw_pitch_roll": yaw_pitch_roll,
27
+ "start": start,
28
+ "audio_driven": audio_driven,
29
+ "face_location": face_location,
30
+ "face_scale": face_scale,
31
+ "control_flag": control_flag,
32
+ },
33
+ )
34
+ else:
35
+ raise NotImplementedError()
src/visualizr/settings.py ADDED
@@ -0,0 +1,29 @@
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel
4
+ from torch.cuda import is_available
5
+
6
+
7
+ class Args(BaseModel):
8
+ test_image_path: str
9
+ test_audio_path: str
10
+ test_hubert_path: str
11
+ result_path: str = "./outputs/"
12
+ stage1_checkpoint_path: str = "ckpts/stage1.ckpt"
13
+ stage2_checkpoint_path: str
14
+ control_flag: bool = True
15
+ pose_driven_path: str = "not_supported_in_this_mode"
16
+ image_size: int = 256
17
+ device: Literal["cuda", "cpu"] = "cuda" if is_available() else "cpu"
18
+ motion_dim: int = 20
19
+ decoder_layers: int = 2
20
+
21
+
22
+ class DefaultValues(BaseModel):
23
+ pose_yaw: float = 0.0
24
+ pose_pitch: float = 0.0
25
+ pose_roll: float = 0.0
26
+ face_location: float = 0.5
27
+ face_scale: float = 0.5
28
+ step_T: int = 50
29
+ seed: int = 0
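Args is a plain pydantic model, so the fields without defaults must be supplied at construction and the rest fall back to their defaults. A hypothetical instantiation (the paths are placeholders, not files shipped with the repo):

args = Args(
    test_image_path="examples/portrait.png",
    test_audio_path="examples/speech.wav",
    test_hubert_path="examples/speech_hubert.npy",
    stage2_checkpoint_path="ckpts/stage2_audio_only_hubert.ckpt",
)
defaults = DefaultValues()
print(args.device, defaults.step_T)   # "cuda" if a GPU is available, otherwise "cpu"; 50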
src/visualizr/templates.py ADDED
@@ -0,0 +1,302 @@
1
+ from visualizr.choices import GenerativeType, ModelName
2
+ from visualizr.config import PretrainConfig, TrainConfig
3
+
4
+
5
+ def ddpm():
6
+ """
7
+ base configuration for all DDIM-based models.
8
+ """
9
+ conf = TrainConfig()
10
+ conf.batch_size = 32
11
+ conf.beatgans_gen_type = GenerativeType.ddim
12
+ conf.beta_scheduler = "linear"
13
+ conf.data_name = "ffhq"
14
+ conf.diffusion_type = "beatgans"
15
+ conf.eval_ema_every_samples = 200_000
16
+ conf.eval_every_samples = 200_000
17
+ conf.fp16 = True
18
+ conf.lr = 1e-4
19
+ conf.model_name = ModelName.beatgans_ddpm
20
+ conf.net_attn = (16,)
21
+ conf.net_beatgans_attn_head = 1
22
+ conf.net_beatgans_embed_channels = 512
23
+ conf.net_ch_mult = (1, 2, 4, 8)
24
+ conf.net_ch = 64
25
+ conf.sample_size = 32
26
+ conf.T_eval = 20
27
+ conf.T = 1000
28
+ conf.make_model_conf()
29
+ return conf
30
+
31
+
32
+ def autoenc_base():
33
+ """
34
+ base configuration for all Diff-AE models.
35
+ """
36
+ conf = TrainConfig()
37
+ conf.batch_size = 32
38
+ conf.beatgans_gen_type = GenerativeType.ddim
39
+ conf.beta_scheduler = "linear"
40
+ conf.data_name = "ffhq"
41
+ conf.diffusion_type = "beatgans"
42
+ conf.eval_ema_every_samples = 200_000
43
+ conf.eval_every_samples = 200_000
44
+ conf.fp16 = True
45
+ conf.lr = 1e-4
46
+ conf.model_name = ModelName.beatgans_autoenc
47
+ conf.net_attn = (16,)
48
+ conf.net_beatgans_attn_head = 1
49
+ conf.net_beatgans_embed_channels = 512
50
+ conf.net_beatgans_resnet_two_cond = True
51
+ conf.net_ch_mult = (1, 2, 4, 8)
52
+ conf.net_ch = 64
53
+ conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
54
+ conf.net_enc_pool = "adaptivenonzero"
55
+ conf.sample_size = 32
56
+ conf.T_eval = 20
57
+ conf.T = 1000
58
+ conf.make_model_conf()
59
+ return conf
60
+
61
+
62
+ def ffhq64_ddpm():
63
+ conf = ddpm()
64
+ conf.data_name = "ffhqlmdb256"
65
+ conf.warmup = 0
66
+ conf.total_samples = 72_000_000
67
+ conf.scale_up_gpus(4)
68
+ return conf
69
+
70
+
71
+ def ffhq64_autoenc():
72
+ conf = autoenc_base()
73
+ conf.data_name = "ffhqlmdb256"
74
+ conf.warmup = 0
75
+ conf.total_samples = 72_000_000
76
+ conf.net_ch_mult = (1, 2, 4, 8)
77
+ conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
78
+ conf.eval_every_samples = 1_000_000
79
+ conf.eval_ema_every_samples = 1_000_000
80
+ conf.scale_up_gpus(4)
81
+ conf.make_model_conf()
82
+ return conf
83
+
84
+
85
+ def celeba64d2c_ddpm():
86
+ conf = ffhq128_ddpm()
87
+ conf.data_name = "celebalmdb"
88
+ conf.eval_every_samples = 10_000_000
89
+ conf.eval_ema_every_samples = 10_000_000
90
+ conf.total_samples = 72_000_000
91
+ conf.name = "celeba64d2c_ddpm"
92
+ return conf
93
+
94
+
95
+ def celeba64d2c_autoenc():
96
+ conf = ffhq64_autoenc()
97
+ conf.data_name = "celebalmdb"
98
+ conf.eval_every_samples = 10_000_000
99
+ conf.eval_ema_every_samples = 10_000_000
100
+ conf.total_samples = 72_000_000
101
+ conf.name = "celeba64d2c_autoenc"
102
+ return conf
103
+
104
+
105
+ def ffhq128_ddpm():
106
+ conf = ddpm()
107
+ conf.data_name = "ffhqlmdb256"
108
+ conf.warmup = 0
109
+ conf.total_samples = 48_000_000
110
+ conf.img_size = 128
111
+ conf.net_ch = 128
112
+ # channels:
113
+ # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
114
+ # sizes:
115
+ # 128 => 128 => 64 => 32 => 16 => 8
116
+ conf.net_ch_mult = (1, 1, 2, 3, 4)
117
+ conf.eval_every_samples = 1_000_000
118
+ conf.eval_ema_every_samples = 1_000_000
119
+ conf.scale_up_gpus(4)
120
+ conf.eval_ema_every_samples = 10_000_000
121
+ conf.eval_every_samples = 10_000_000
122
+ conf.make_model_conf()
123
+ return conf
124
+
125
+
126
+ def ffhq128_autoenc_base():
127
+ conf = autoenc_base()
128
+ conf.data_name = "ffhqlmdb256"
129
+ conf.scale_up_gpus(4)
130
+ conf.img_size = 128
131
+ conf.net_ch = 128
132
+ # final resolution = 8x8
133
+ conf.net_ch_mult = (1, 1, 2, 3, 4)
134
+ # final resolution = 4x4
135
+ conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
136
+ conf.eval_ema_every_samples = 10_000_000
137
+ conf.eval_every_samples = 10_000_000
138
+ conf.make_model_conf()
139
+ return conf
140
+
141
+
142
+ def ffhq256_autoenc():
143
+ conf = ffhq128_autoenc_base()
144
+ conf.img_size = 256
145
+ conf.net_ch = 128
146
+ conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
147
+ conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
148
+ conf.eval_every_samples = 10_000_000
149
+ conf.eval_ema_every_samples = 10_000_000
150
+ conf.total_samples = 200_000_000
151
+ conf.batch_size = 64
152
+ conf.make_model_conf()
153
+ conf.name = "ffhq256_autoenc"
154
+ return conf
155
+
156
+
157
+ def ffhq256_autoenc_eco():
158
+ conf = ffhq128_autoenc_base()
159
+ conf.img_size = 256
160
+ conf.net_ch = 128
161
+ conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
162
+ conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
163
+ conf.eval_every_samples = 10_000_000
164
+ conf.eval_ema_every_samples = 10_000_000
165
+ conf.total_samples = 200_000_000
166
+ conf.batch_size = 64
167
+ conf.make_model_conf()
168
+ conf.name = "ffhq256_autoenc_eco"
169
+ return conf
170
+
171
+
172
+ def ffhq128_ddpm_72M():
173
+ conf = ffhq128_ddpm()
174
+ conf.total_samples = 72_000_000
175
+ conf.name = "ffhq128_ddpm_72M"
176
+ return conf
177
+
178
+
179
+ def ffhq128_autoenc_72M():
180
+ conf = ffhq128_autoenc_base()
181
+ conf.total_samples = 72_000_000
182
+ conf.name = "ffhq128_autoenc_72M"
183
+ return conf
184
+
185
+
186
+ def ffhq128_ddpm_130M():
187
+ conf = ffhq128_ddpm()
188
+ conf.total_samples = 130_000_000
189
+ conf.eval_ema_every_samples = 10_000_000
190
+ conf.eval_every_samples = 10_000_000
191
+ conf.name = "ffhq128_ddpm_130M"
192
+ return conf
193
+
194
+
195
+ def ffhq128_autoenc_130M():
196
+ conf = ffhq128_autoenc_base()
197
+ conf.total_samples = 130_000_000
198
+ conf.eval_ema_every_samples = 10_000_000
199
+ conf.eval_every_samples = 10_000_000
200
+ conf.name = "ffhq128_autoenc_130M"
201
+ return conf
202
+
203
+
204
+ def horse128_ddpm():
205
+ conf = ffhq128_ddpm()
206
+ conf.data_name = "horse256"
207
+ conf.total_samples = 130_000_000
208
+ conf.eval_ema_every_samples = 10_000_000
209
+ conf.eval_every_samples = 10_000_000
210
+ conf.name = "horse128_ddpm"
211
+ return conf
212
+
213
+
214
+ def horse128_autoenc():
215
+ conf = ffhq128_autoenc_base()
216
+ conf.data_name = "horse256"
217
+ conf.total_samples = 130_000_000
218
+ conf.eval_ema_every_samples = 10_000_000
219
+ conf.eval_every_samples = 10_000_000
220
+ conf.name = "horse128_autoenc"
221
+ return conf
222
+
223
+
224
+ def bedroom128_ddpm():
225
+ conf = ffhq128_ddpm()
226
+ conf.data_name = "bedroom256"
227
+ conf.eval_ema_every_samples = 10_000_000
228
+ conf.eval_every_samples = 10_000_000
229
+ conf.total_samples = 120_000_000
230
+ conf.name = "bedroom128_ddpm"
231
+ return conf
232
+
233
+
234
+ def bedroom128_autoenc():
235
+ conf = ffhq128_autoenc_base()
236
+ conf.data_name = "bedroom256"
237
+ conf.eval_ema_every_samples = 10_000_000
238
+ conf.eval_every_samples = 10_000_000
239
+ conf.total_samples = 120_000_000
240
+ conf.name = "bedroom128_autoenc"
241
+ return conf
242
+
243
+
244
+ def pretrain_celeba64d2c_72M():
245
+ conf = celeba64d2c_autoenc()
246
+ conf.pretrain = PretrainConfig(
247
+ name="72M",
248
+ path=f"checkpoints/{celeba64d2c_autoenc().name}/last.ckpt",
249
+ )
250
+ conf.latent_infer_path = f"checkpoints/{celeba64d2c_autoenc().name}/latent.pkl"
251
+ return conf
252
+
253
+
254
+ def pretrain_ffhq128_autoenc72M():
255
+ conf = ffhq128_autoenc_base()
256
+ conf.postfix = ""
257
+ conf.pretrain = PretrainConfig(
258
+ name="72M",
259
+ path=f"checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt",
260
+ )
261
+ conf.latent_infer_path = f"checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl"
262
+ return conf
263
+
264
+
265
+ def pretrain_ffhq128_autoenc130M():
266
+ conf = ffhq128_autoenc_base()
267
+ conf.pretrain = PretrainConfig(
268
+ name="130M",
269
+ path=f"checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt",
270
+ )
271
+ conf.latent_infer_path = f"checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl"
272
+ return conf
273
+
274
+
275
+ def pretrain_ffhq256_autoenc():
276
+ conf = ffhq256_autoenc()
277
+ conf.pretrain = PretrainConfig(
278
+ name="90M",
279
+ path=f"checkpoints/{ffhq256_autoenc().name}/last.ckpt",
280
+ )
281
+ conf.latent_infer_path = f"checkpoints/{ffhq256_autoenc().name}/latent.pkl"
282
+ return conf
283
+
284
+
285
+ def pretrain_horse128():
286
+ conf = horse128_autoenc()
287
+ conf.pretrain = PretrainConfig(
288
+ name="82M",
289
+ path=f"checkpoints/{horse128_autoenc().name}/last.ckpt",
290
+ )
291
+ conf.latent_infer_path = f"checkpoints/{horse128_autoenc().name}/latent.pkl"
292
+ return conf
293
+
294
+
295
+ def pretrain_bedroom128():
296
+ conf = bedroom128_autoenc()
297
+ conf.pretrain = PretrainConfig(
298
+ name="120M",
299
+ path=f"checkpoints/{bedroom128_autoenc().name}/last.ckpt",
300
+ )
301
+ conf.latent_infer_path = f"checkpoints/{bedroom128_autoenc().name}/latent.pkl"
302
+ return conf
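Each template is a plain factory that mutates a fresh TrainConfig and returns it; downstream code picks one and overrides fields. A sketch mirroring how init_conf in src/visualizr/utils.py consumes ffhq256_autoenc():

conf = ffhq256_autoenc()
conf.seed = 0
conf.decoder_layers = 2
conf.motion_dim = 20
conf.infer_type = "hubert_audio_only"
print(conf.name, conf.img_size)   # "ffhq256_autoenc", 256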
src/visualizr/utils.py ADDED
@@ -0,0 +1,432 @@
1
+ import os
2
+ import shutil
3
+ import sys
4
+ import time
5
+ from importlib.util import find_spec
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ import gradio as gr
10
+ import librosa
11
+ import numpy as np
12
+ import python_speech_features
13
+ import spaces
14
+ import torch
15
+ from gradio import Markdown
16
+ from moviepy.editor import (
17
+ AudioFileClip,
18
+ ImageClip,
19
+ VideoFileClip,
20
+ concatenate_videoclips,
21
+ )
22
+ from PIL import Image
23
+ from torch import Tensor
24
+ from torchvision.transforms import ToPILImage
25
+ from tqdm import tqdm
26
+
27
+ from visualizr import (
28
+ FRAMES_RESULT_SAVED_PATH,
29
+ MOTION_DIM,
30
+ RESULTS_DIR,
31
+ STAGE_1_CHECKPOINT_PATH,
32
+ TMP_MP4,
33
+ logger,
34
+ model_mapping,
35
+ )
36
+ from visualizr.config import TrainConfig
37
+ from visualizr.experiment import LitModel
38
+ from visualizr.LIA_Model import LIA_Model
39
+ from visualizr.templates import ffhq256_autoenc
40
+
41
+
42
+ def check_package_installed(package_name: str) -> bool:
43
+ return find_spec(package_name) is not None
44
+
45
+
46
+ def frames_to_video(input_path, audio_path, output_path, fps=25):
47
+ image_files = [
48
+ os.path.join(input_path, img) for img in sorted(os.listdir(input_path))
49
+ ]
50
+ clips = [ImageClip(m).set_duration(1 / fps) for m in image_files]
51
+ video = concatenate_videoclips(clips, method="compose")
52
+ audio = AudioFileClip(audio_path)
53
+ final_video = video.set_audio(audio)
54
+ final_video.write_videofile(output_path, fps, "libx264", audio_codec="aac")
55
+
56
+
57
+ def load_image(filename: str, size: int) -> np.ndarray:
58
+ img: Image.Image = Image.open(filename).convert("RGB")
59
+ img_resized: Image.Image = img.resize((size, size))
60
+ img_np: np.ndarray = np.asarray(img_resized)
61
+ img_transposed: np.ndarray = np.transpose(img_np, (2, 0, 1)) # 3 x 256 x 256
62
+ return img_transposed / 255.0
63
+
64
+
65
+ def img_preprocessing(img_path: str, size: int) -> Tensor:
66
+ img_np: np.ndarray = load_image(img_path, size) # [0, 1]
67
+ img: Tensor = torch.from_numpy(img_np).unsqueeze(0).float() # [0, 1]
68
+ normalized_image: Tensor = (img - 0.5) * 2.0 # [-1, 1]
69
+ return normalized_image
70
+
71
+
72
+ def saved_image(img_tensor: Tensor, img_path: str) -> None:
73
+ pil_image_converter: ToPILImage = ToPILImage()
74
+ img = pil_image_converter(img_tensor.detach().cpu().squeeze(0))
75
+ img.save(img_path)
76
+
77
+
78
+ def load_stage_1_model() -> LIA_Model:
79
+ logger.info("Loading stage 1 model... ")
80
+ lia: LIA_Model = LIA_Model(motion_dim=MOTION_DIM, fusion_type="weighted_sum")
81
+ lia.load_lightning_model(STAGE_1_CHECKPOINT_PATH)
82
+ lia.to("cuda")
83
+ return lia
84
+
85
+
86
+ def load_stage_2_model(conf: TrainConfig, stage2_checkpoint_path: str) -> LitModel:
87
+ logger.info("Loading stage 2 model... ")
88
+ model = LitModel(conf)
89
+ state = torch.load(stage2_checkpoint_path, "cpu")
90
+ model.load_state_dict(state)
91
+ model.ema_model.eval()
92
+ model.ema_model.to("cuda")
93
+ return model
94
+
95
+
96
+ def init_conf(
97
+ infer_type: Literal[
98
+ "mfcc_full_control",
99
+ "mfcc_pose_only",
100
+ "hubert_pose_only",
101
+ "hubert_audio_only",
102
+ "hubert_full_control",
103
+ ],
104
+ seed: int,
105
+ ) -> TrainConfig:
106
+ logger.info("Initializing configuration... ")
107
+ conf: TrainConfig = ffhq256_autoenc()
108
+ conf.seed = seed
109
+ conf.decoder_layers = 2
110
+ conf.infer_type = infer_type
111
+ conf.motion_dim = MOTION_DIM
112
+ logger.info(f"infer_type: {infer_type}")
113
+ match infer_type:
114
+ case "mfcc_full_control":
115
+ conf.face_location = True
116
+ conf.face_scale = True
117
+ conf.mfcc = True
118
+ case "mfcc_pose_only":
119
+ conf.face_location = False
120
+ conf.face_scale = False
121
+ conf.mfcc = True
122
+ case "hubert_pose_only":
123
+ conf.face_location = False
124
+ conf.face_scale = False
125
+ conf.mfcc = False
126
+ case "hubert_audio_only":
127
+ conf.face_location = False
128
+ conf.face_scale = False
129
+ conf.mfcc = False
130
+ case "hubert_full_control":
131
+ conf.face_location = True
132
+ conf.face_scale = True
133
+ conf.mfcc = False
134
+ return conf
135
+
136
+
137
+ def main(
138
+ infer_type: Literal[
139
+ "mfcc_full_control",
140
+ "mfcc_pose_only",
141
+ "hubert_pose_only",
142
+ "hubert_audio_only",
143
+ "hubert_full_control",
144
+ ],
145
+ image_path: str,
146
+ test_audio_path: str,
147
+ face_sr: bool,
148
+ pose_yaw: float,
149
+ pose_pitch: float,
150
+ pose_roll: float,
151
+ face_location: float,
152
+ face_scale: float,
153
+ step_t: int,
154
+ seed: int,
155
+ stage2_checkpoint_path: str,
156
+ ):
157
+ frame_end, audio_driven = None, None # set below by the mfcc/hubert feature branch
158
+ if not os.path.exists(image_path):
159
+ logger.exception(f"{image_path} does not exist!")
160
+ sys.exit(1)
161
+ if not os.path.exists(test_audio_path):
162
+ logger.exception(f"{test_audio_path} does not exist!")
163
+ sys.exit(1)
164
+
165
+ image_name: str = Path(image_path).stem
166
+ audio_name: str = Path(test_audio_path).stem
167
+
168
+ predicted_video_256_path: Path = RESULTS_DIR / f"{image_name}-{audio_name}.mp4"
169
+ predicted_video_512_path: Path = RESULTS_DIR / f"{image_name}-{audio_name}_SR.mp4"
170
+
171
+ # ======Loading Stage 1 model=========
172
+ lia: LIA_Model = load_stage_1_model()
173
+ # ============================
174
+
175
+ conf: TrainConfig = init_conf(infer_type, seed)
176
+
177
+ img_source: Tensor = img_preprocessing(image_path, 256).to("cuda")
178
+ one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(
179
+ img_source, img_source, img_source, img_source
180
+ )
181
+
182
+ # ======Loading Stage 2 model=========
183
+ model = load_stage_2_model(conf, stage2_checkpoint_path)
184
+ # =================================
185
+
186
+ # ======Audio Input=========
187
+ if conf.infer_type.startswith("mfcc"):
188
+ # MFCC features
189
+ wav, sr = librosa.load(test_audio_path, sr=16000)
190
+ input_values = python_speech_features.mfcc(
191
+ signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01
192
+ )
193
+ d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
194
+ d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
195
+ audio_driven_obj: np.ndarray = np.hstack(
196
+ (input_values, d_mfcc_feat, d_mfcc_feat2)
197
+ )
198
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[0] / 4)
199
+ audio_start, audio_end = (
200
+ int(frame_start * 4),
201
+ int(frame_end * 4),
202
+ ) # the video frame rate is fixed at 25 Hz and the MFCC features at 100 Hz
203
+
204
+ audio_driven = (
205
+ torch.Tensor(audio_driven_obj[audio_start:audio_end, :])
206
+ .unsqueeze(0)
207
+ .float()
208
+ .to("cuda")
209
+ )
210
+
211
+ elif conf.infer_type.startswith("hubert"):
212
+ # Hubert features
213
+ if not check_package_installed("transformers"):
214
+ logger.exception("Please install transformers module first.")
215
+ sys.exit(1)
216
+ hubert_model_path = "ckpts/chinese-hubert-large"
217
+ if not os.path.exists(hubert_model_path):
218
+ logger.exception(
219
+ "Please download the hubert weight into the ckpts path first."
220
+ )
221
+ sys.exit(1)
222
+ logger.info(
223
+ "You did not extract the audio features in advance, "
224
+ + "extracting online now, which will increase processing delay"
225
+ )
226
+
227
+ start_time = time.time()
228
+
229
+ # load hubert model
230
+ from transformers import HubertModel, Wav2Vec2FeatureExtractor
231
+
232
+ audio_model = HubertModel.from_pretrained(hubert_model_path).to("cuda")
233
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
234
+ audio_model.feature_extractor._freeze_parameters() # skipcq: PYL-W0212
235
+ audio_model.eval()
236
+
237
+ # hubert model forward pass
238
+ audio, sr = librosa.load(test_audio_path, sr=16000)
239
+ input_values = feature_extractor(
240
+ audio,
241
+ sampling_rate=16000,
242
+ padding=True,
243
+ do_normalize=True,
244
+ return_tensors="pt",
245
+ ).input_values
246
+ input_values = input_values.to("cuda")
247
+ ws_feats = []
248
+ with torch.no_grad():
249
+ outputs = audio_model(input_values, output_hidden_states=True)
250
+ for i in range(len(outputs.hidden_states)):
251
+ ws_feats.append(outputs.hidden_states[i].detach().cpu().numpy())
252
+ ws_feat_obj = np.array(ws_feats)
253
+ ws_feat_obj = np.squeeze(ws_feat_obj, 1)
254
+ ws_feat_obj = np.pad(
255
+ ws_feat_obj, ((0, 0), (0, 1), (0, 0)), "edge"
256
+ ) # align the audio length with the video frame
257
+
258
+ execution_time = time.time() - start_time
259
+ logger.info(f"Extraction Audio Feature: {execution_time:.2f} Seconds")
260
+
261
+ audio_driven_obj = ws_feat_obj
262
+
263
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[1] / 2)
264
+ audio_start, audio_end = (
265
+ int(frame_start * 2),
266
+ int(frame_end * 2),
267
+ ) # the video frame rate is fixed at 25 Hz and the HuBERT features at 50 Hz
268
+
269
+ audio_driven = (
270
+ torch.Tensor(audio_driven_obj[:, audio_start:audio_end, :])
271
+ .unsqueeze(0)
272
+ .float()
273
+ .to("cuda")
274
+ )
275
+ # ============================
276
+
277
+ # Diffusion Noise
278
+ noisy_t = torch.randn((1, frame_end, MOTION_DIM)).to("cuda")
279
+
280
+ # ======Inputs for Attribute Control=========
281
+ yaw_signal = torch.zeros(1, frame_end, 1).to("cuda") + pose_yaw
282
+ pitch_signal = torch.zeros(1, frame_end, 1).to("cuda") + pose_pitch
283
+ roll_signal = torch.zeros(1, frame_end, 1).to("cuda") + pose_roll
284
+ pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
285
+
286
+ pose_signal = torch.clamp(pose_signal, -1, 1)
287
+
288
+ face_location_signal = torch.zeros(1, frame_end, 1).to("cuda") + face_location
289
+ face_scale_tensor = torch.zeros(1, frame_end, 1).to("cuda") + face_scale
290
+ # ===========================================
291
+ start_time = time.time()
292
+ # ======Diffusion Denoising Process=========
293
+ generated_directions = model.render(
294
+ one_shot_lia_start,
295
+ one_shot_lia_direction,
296
+ audio_driven,
297
+ face_location_signal,
298
+ face_scale_tensor,
299
+ pose_signal,
300
+ noisy_t,
301
+ step_t,
302
+ True,
303
+ )
304
+ # =========================================
305
+
306
+ execution_time = time.time() - start_time
307
+ logger.info(f"Motion Diffusion Model: {execution_time:.2f} Seconds")
308
+
309
+ generated_directions = generated_directions.detach().cpu().numpy()
310
+
311
+ start_time = time.time()
312
+ # ======Rendering images frame-by-frame=========
313
+ for pred_index in tqdm(range(generated_directions.shape[1])):
314
+ ori_img_recon = lia.render(
315
+ one_shot_lia_start,
316
+ torch.Tensor(generated_directions[:, pred_index, :]).to("cuda"),
317
+ feats,
318
+ )
319
+ ori_img_recon = ori_img_recon.clamp(-1, 1)
320
+ img_pred = (ori_img_recon.detach() + 1) / 2
321
+ saved_image(
322
+ img_pred, os.path.join(FRAMES_RESULT_SAVED_PATH, f"{pred_index:06d}.png")
323
+ )
324
+ # ==============================================
325
+
326
+ execution_time = time.time() - start_time
327
+ logger.info(f"Renderer Model: {execution_time:.2f} Seconds")
328
+ logger.info(f"Saving video at {predicted_video_256_path}")
329
+
330
+ frames_to_video(
331
+ str(FRAMES_RESULT_SAVED_PATH),
332
+ test_audio_path,
333
+ str(predicted_video_256_path),
334
+ )
335
+
336
+ shutil.rmtree(FRAMES_RESULT_SAVED_PATH)
337
+
338
+ # Enhancer
339
+ if face_sr and check_package_installed("gfpgan"):
340
+ from imageio import mimsave
341
+
342
+ from visualizr.face_sr.face_enhancer import enhancer_list
343
+
344
+ # Super-resolution
345
+ mimsave(
346
+ str(predicted_video_512_path.with_name(TMP_MP4)), # temp SR video next to the final output
347
+ enhancer_list(predicted_video_256_path, bg_upsampler=None),
348
+ fps=25.0,
349
+ )
350
+
351
+ # Merge audio and video
352
+ video_clip = VideoFileClip(str(predicted_video_512_path.with_name(TMP_MP4)))
353
+ audio_clip = AudioFileClip(str(predicted_video_256_path))
354
+ final_clip = video_clip.set_audio(audio_clip)
355
+ final_clip.write_videofile(
356
+ str(predicted_video_512_path), codec="libx264", audio_codec="aac"
357
+ )
358
+
359
+ os.remove(predicted_video_512_path.with_name(TMP_MP4))
360
+
361
+ if face_sr:
362
+ return predicted_video_256_path, predicted_video_512_path
363
+ return predicted_video_256_path, predicted_video_256_path
364
+
365
+
366
+ @spaces.GPU(duration=300)
367
+ def generate_video(
368
+ uploaded_img: str,
369
+ uploaded_audio: str,
370
+ infer_type: Literal[
371
+ "mfcc_full_control",
372
+ "mfcc_pose_only",
373
+ "hubert_pose_only",
374
+ "hubert_audio_only",
375
+ "hubert_full_control",
376
+ ],
377
+ pose_yaw: float,
378
+ pose_pitch: float,
379
+ pose_roll: float,
380
+ face_location: float,
381
+ face_scale: float,
382
+ step_t: int,
383
+ face_sr: bool,
384
+ seed: int,
385
+ ):
386
+ if not uploaded_img or not uploaded_audio:
387
+ return None, None, Markdown(
388
+ "Error: Input image or audio file is empty. "
389
+ + "Please check and upload both files."
390
+ )
391
+ try:
392
+ output_256_video_path, output_512_video_path = main(
393
+ infer_type,
394
+ uploaded_img,
395
+ uploaded_audio,
396
+ face_sr,
397
+ pose_yaw,
398
+ pose_pitch,
399
+ pose_roll,
400
+ face_location,
401
+ face_scale,
402
+ step_t,
403
+ seed,
404
+ model_mapping.get(
405
+ infer_type,
406
+ "default_checkpoint.ckpt",
407
+ ),
408
+ )
409
+
410
+ if not os.path.exists(output_256_video_path):
411
+ return None, None, gr.Markdown(
412
+ "Error: Video generation failed. "
413
+ + "Please check your inputs and try again."
414
+ )
415
+ if output_256_video_path == output_512_video_path:
416
+ return (
417
+ gr.Video(value=output_256_video_path),
418
+ None,
419
+ gr.Markdown("Video (256*256 only) generated successfully!"),
420
+ )
421
+ return (
422
+ gr.Video(value=output_256_video_path),
423
+ gr.Video(value=output_512_video_path),
424
+ gr.Markdown("Video generated successfully!"),
425
+ )
426
+
427
+ except Exception as e:
428
+ return (
429
+ None,
430
+ None,
431
+ gr.Markdown(f"Error: An unexpected error occurred - {str(e)}"),
432
+ )
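generate_video is the Gradio-facing wrapper around main(). A hypothetical direct call to main() using the defaults from settings.DefaultValues; the paths and checkpoint name are placeholders, and a CUDA device plus the downloaded checkpoints are required:

video_256, video_hi = main(
    infer_type="hubert_audio_only",
    image_path="examples/portrait.png",
    test_audio_path="examples/speech.wav",
    face_sr=False,
    pose_yaw=0.0,
    pose_pitch=0.0,
    pose_roll=0.0,
    face_location=0.5,
    face_scale=0.5,
    step_t=50,
    seed=0,
    stage2_checkpoint_path="ckpts/stage2_audio_only_hubert.ckpt",
)
print(video_256)   # RESULTS_DIR / "portrait-speech.mp4"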
uv.lock ADDED
The diff for this file is too large to render. See raw diff