torinriley commited on
Commit
ef6c3c2
·
verified ·
1 Parent(s): cb921a0

Upload 26 files

Browse files
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ from PIL import Image
6
+ import os
7
+ from huggingface_hub import hf_hub_download
8
+ from pathlib import Path
9
+ import sys
10
+
11
+ # Add src directory to Python path
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+
14
+ from src import model_loader
15
+ from src import pipeline
16
+ from src.config import Config, DeviceConfig
17
+ from transformers import CLIPTokenizer
18
+
19
+ # Create data directory if it doesn't exist
20
+ data_dir = Path("data")
21
+ data_dir.mkdir(exist_ok=True)
22
+
23
+ # Model configuration
24
+ MODEL_REPO = "stable-diffusion-v1-5/stable-diffusion-v1-5"
25
+ MODEL_FILENAME = "v1-5-pruned-emaonly.ckpt"
26
+ model_file = data_dir / MODEL_FILENAME
27
+
28
+ # Download model if it doesn't exist
29
+ if not model_file.exists():
30
+ print(f"Downloading model from {MODEL_REPO}...")
31
+ model_file = hf_hub_download(
32
+ repo_id=MODEL_REPO,
33
+ filename=MODEL_FILENAME,
34
+ local_dir=data_dir,
35
+ local_dir_use_symlinks=False
36
+ )
37
+ print("Model downloaded successfully!")
38
+
39
+ # Device configuration
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ print(f"Using device: {device}")
42
+
43
+ # Initialize configuration
44
+ config = Config(
45
+ device=DeviceConfig(device=device),
46
+ tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
47
+ )
48
+
49
+ # Load models with SE blocks enabled
50
+ config.models = model_loader.load_models(str(model_file), device, use_se=True)
51
+
52
+ MAX_SEED = np.iinfo(np.int32).max
53
+ MAX_IMAGE_SIZE = 1024
54
+
55
+ def infer(
56
+ prompt,
57
+ negative_prompt,
58
+ seed,
59
+ randomize_seed,
60
+ width,
61
+ height,
62
+ guidance_scale,
63
+ num_inference_steps,
64
+ progress=gr.Progress(track_tqdm=True),
65
+ ):
66
+ if randomize_seed:
67
+ seed = random.randint(0, MAX_SEED)
68
+
69
+ # Update config with user settings
70
+ config.seed = seed
71
+ config.diffusion.cfg_scale = guidance_scale
72
+ config.diffusion.n_inference_steps = num_inference_steps
73
+ config.model.width = width
74
+ config.model.height = height
75
+
76
+ # Generate image
77
+ output_image = pipeline.generate(
78
+ prompt=prompt,
79
+ uncond_prompt=negative_prompt,
80
+ config=config
81
+ )
82
+
83
+ # Convert numpy array to PIL Image
84
+ image = Image.fromarray(output_image)
85
+
86
+ return image, seed
87
+
88
+ examples = [
89
+ "A ultra sharp photorealtici painting of a futuristic cityscape at night with neon lights and flying cars",
90
+ "A serene mountain landscape at sunset with snow-capped peaks and a clear lake reflection",
91
+ "A detailed portrait of a cyberpunk character with glowing neon implants and holographic tattoos",
92
+ ]
93
+
94
+ css = """
95
+ #col-container {
96
+ margin: 0 auto;
97
+ max-width: 640px;
98
+ }
99
+ """
100
+
101
+ with gr.Blocks(css=css) as demo:
102
+ with gr.Column(elem_id="col-container"):
103
+ gr.Markdown(" # Custom Diffusion Model Text-to-Image Generator")
104
+
105
+ with gr.Row():
106
+ prompt = gr.Text(
107
+ label="Prompt",
108
+ show_label=False,
109
+ max_lines=1,
110
+ placeholder="Enter your prompt",
111
+ container=False,
112
+ )
113
+
114
+ run_button = gr.Button("Run", scale=0, variant="primary")
115
+
116
+ result = gr.Image(label="Result", show_label=False)
117
+
118
+ with gr.Accordion("Advanced Settings", open=False):
119
+ negative_prompt = gr.Text(
120
+ label="Negative prompt",
121
+ max_lines=1,
122
+ placeholder="Enter a negative prompt",
123
+ visible=False,
124
+ )
125
+
126
+ seed = gr.Slider(
127
+ label="Seed",
128
+ minimum=0,
129
+ maximum=MAX_SEED,
130
+ step=1,
131
+ value=42,
132
+ )
133
+
134
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
135
+
136
+ with gr.Row():
137
+ width = gr.Slider(
138
+ label="Width",
139
+ minimum=256,
140
+ maximum=MAX_IMAGE_SIZE,
141
+ step=32,
142
+ value=512,
143
+ )
144
+
145
+ height = gr.Slider(
146
+ label="Height",
147
+ minimum=256,
148
+ maximum=MAX_IMAGE_SIZE,
149
+ step=32,
150
+ value=512,
151
+ )
152
+
153
+ with gr.Row():
154
+ guidance_scale = gr.Slider(
155
+ label="Guidance scale",
156
+ minimum=0.0,
157
+ maximum=10.0,
158
+ step=0.1,
159
+ value=7.5,
160
+ )
161
+
162
+ num_inference_steps = gr.Slider(
163
+ label="Number of inference steps",
164
+ minimum=1,
165
+ maximum=50,
166
+ step=1,
167
+ value=50,
168
+ )
169
+
170
+ gr.Examples(examples=examples, inputs=[prompt])
171
+
172
+ gr.on(
173
+ triggers=[run_button.click, prompt.submit],
174
+ fn=infer,
175
+ inputs=[
176
+ prompt,
177
+ negative_prompt,
178
+ seed,
179
+ randomize_seed,
180
+ width,
181
+ height,
182
+ guidance_scale,
183
+ num_inference_steps,
184
+ ],
185
+ outputs=[result, seed],
186
+ )
187
+
188
+ if __name__ == "__main__":
189
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ gradio>=4.0.0
3
+ transformers>=4.30.0
4
+ numpy>=1.24.0
5
+ Pillow>=10.0.0
6
+ huggingface_hub>=0.19.0
7
+ accelerate>=0.25.0
8
+ safetensors>=0.4.0
9
+ setuptools>=65.5.1
setup.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="custom-diffusion",
5
+ version="0.1.0",
6
+ packages=find_packages(),
7
+ install_requires=[
8
+ "torch>=2.0.0",
9
+ "gradio>=4.0.0",
10
+ "transformers>=4.30.0",
11
+ "numpy>=1.24.0",
12
+ "Pillow>=10.0.0",
13
+ "huggingface_hub>=0.19.0",
14
+ "accelerate>=0.25.0",
15
+ "safetensors>=0.4.0",
16
+ ],
17
+ )
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # This file makes the src directory a Python package
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (196 Bytes). View file
 
src/__pycache__/attention.cpython-312.pyc ADDED
Binary file (4.69 kB). View file
 
src/__pycache__/clip.cpython-312.pyc ADDED
Binary file (4.02 kB). View file
 
src/__pycache__/config.cpython-312.pyc ADDED
Binary file (3.4 kB). View file
 
src/__pycache__/ddpm.cpython-312.pyc ADDED
Binary file (6.46 kB). View file
 
src/__pycache__/decoder.cpython-312.pyc ADDED
Binary file (4.93 kB). View file
 
src/__pycache__/diffusion.cpython-312.pyc ADDED
Binary file (14.2 kB). View file
 
src/__pycache__/encoder.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
src/__pycache__/model_converter.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc31a7458a7d5afc6251204fd5949d56297f0e0bc97b6b307d2d70b3e2b38d97
3
+ size 170127
src/__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (1.86 kB). View file
 
src/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (8.11 kB). View file
 
src/attention.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ import math
5
+
6
+ class SelfAttention(nn.Module):
7
+ def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
8
+ super().__init__()
9
+ self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
10
+ self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
11
+ self.n_heads = n_heads
12
+ self.d_head = d_embed // n_heads
13
+
14
+ def forward(self, x, causal_mask=False):
15
+ input_shape = x.shape
16
+ batch_size, sequence_length, d_embed = input_shape
17
+ interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
18
+
19
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
20
+ q = q.view(interim_shape).transpose(1, 2)
21
+ k = k.view(interim_shape).transpose(1, 2)
22
+ v = v.view(interim_shape).transpose(1, 2)
23
+
24
+ weight = q @ k.transpose(-1, -2)
25
+
26
+ if causal_mask:
27
+ mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
28
+ weight.masked_fill_(mask, -torch.inf)
29
+
30
+ weight /= math.sqrt(self.d_head)
31
+ weight = F.softmax(weight, dim=-1)
32
+ output = weight @ v
33
+ output = output.transpose(1, 2).reshape(input_shape)
34
+ output = self.out_proj(output)
35
+
36
+ return output
37
+
38
+ class CrossAttention(nn.Module):
39
+ def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
40
+ super().__init__()
41
+ self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
42
+ self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
43
+ self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
44
+ self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
45
+ self.n_heads = n_heads
46
+ self.d_head = d_embed // n_heads
47
+
48
+ def forward(self, x, y):
49
+ input_shape = x.shape
50
+ batch_size, sequence_length, d_embed = input_shape
51
+ interim_shape = (batch_size, -1, self.n_heads, self.d_head)
52
+
53
+ q = self.q_proj(x)
54
+ k = self.k_proj(y)
55
+ v = self.v_proj(y)
56
+
57
+ q = q.view(interim_shape).transpose(1, 2)
58
+ k = k.view(interim_shape).transpose(1, 2)
59
+ v = v.view(interim_shape).transpose(1, 2)
60
+
61
+ weight = q @ k.transpose(-1, -2)
62
+ weight /= math.sqrt(self.d_head)
63
+ weight = F.softmax(weight, dim=-1)
64
+ output = weight @ v
65
+ output = output.transpose(1, 2).contiguous().view(input_shape)
66
+ output = self.out_proj(output)
67
+
68
+ return output
69
+
src/clip.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from .attention import SelfAttention
5
+
6
+ class CLIPEmbedding(nn.Module):
7
+ def __init__(self, n_vocab: int, n_embd: int, n_token: int):
8
+ super().__init__()
9
+ self.token_embedding = nn.Embedding(n_vocab, n_embd)
10
+ self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
11
+
12
+ def forward(self, tokens):
13
+ x = self.token_embedding(tokens)
14
+ x += self.position_embedding
15
+ return x
16
+
17
+ class CLIPLayer(nn.Module):
18
+ def __init__(self, n_head: int, n_embd: int):
19
+ super().__init__()
20
+ self.layernorm_1 = nn.LayerNorm(n_embd)
21
+ self.attention = SelfAttention(n_head, n_embd)
22
+ self.layernorm_2 = nn.LayerNorm(n_embd)
23
+ self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
24
+ self.linear_2 = nn.Linear(4 * n_embd, n_embd)
25
+ self.activation = nn.GELU()
26
+
27
+ def forward(self, x):
28
+ residue = x
29
+ x = self.layernorm_1(x)
30
+ x = self.attention(x, causal_mask=True)
31
+ x += residue
32
+
33
+ residue = x
34
+ x = self.layernorm_2(x)
35
+ x = self.linear_1(x)
36
+ x = self.activation(x)
37
+ x = self.linear_2(x)
38
+ x += residue
39
+
40
+ return x
41
+
42
+ class CLIP(nn.Module):
43
+ def __init__(self):
44
+ super().__init__()
45
+ self.embedding = CLIPEmbedding(49408, 768, 77)
46
+ self.layers = nn.ModuleList([CLIPLayer(12, 768) for _ in range(12)])
47
+ self.layernorm = nn.LayerNorm(768)
48
+
49
+ def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
50
+ state = self.embedding(tokens)
51
+ for layer in self.layers:
52
+ state = layer(state)
53
+ output = self.layernorm(state)
54
+ return output
src/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, Dict, Any
3
+ import torch
4
+
5
+ @dataclass
6
+ class ModelConfig:
7
+ # Image dimensions
8
+ width: int = 512
9
+ height: int = 512
10
+ latents_width: int = 64 # width // 8
11
+ latents_height: int = 64 # height // 8
12
+
13
+ # Model architecture parameters
14
+ n_embd: int = 1280
15
+ n_head: int = 8
16
+ d_context: int = 768
17
+
18
+ # UNet parameters
19
+ n_time: int = 1280
20
+ n_channels: int = 4
21
+ n_residual_blocks: int = 2
22
+
23
+ # Attention parameters
24
+ attention_heads: int = 8
25
+ attention_dim: int = 1280
26
+
27
+ @dataclass
28
+ class DiffusionConfig:
29
+ # Sampling parameters
30
+ n_inference_steps: int = 50
31
+ guidance_scale: float = 7.5
32
+ strength: float = 0.8
33
+
34
+ # Sampler configuration
35
+ sampler_name: str = "ddpm"
36
+ beta_start: float = 0.00085
37
+ beta_end: float = 0.0120
38
+ beta_schedule: str = "linear"
39
+
40
+ # Conditioning parameters
41
+ do_cfg: bool = True
42
+ cfg_scale: float = 7.5
43
+
44
+ @dataclass
45
+ class DeviceConfig:
46
+ device: Optional[str] = None
47
+ idle_device: Optional[str] = None
48
+
49
+ def __post_init__(self):
50
+ if self.device is None:
51
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ if self.idle_device is None:
53
+ self.idle_device = "cpu"
54
+
55
+ @dataclass
56
+ class Config:
57
+ model: ModelConfig = field(default_factory=ModelConfig)
58
+ diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
59
+ device: DeviceConfig = field(default_factory=DeviceConfig)
60
+
61
+ # Additional settings
62
+ seed: Optional[int] = None
63
+ tokenizer: Optional[Any] = None
64
+ models: Dict[str, Any] = field(default_factory=dict)
65
+
66
+ def __post_init__(self):
67
+ # Update latent dimensions based on image dimensions
68
+ self.model.latents_width = self.model.width // 8
69
+ self.model.latents_height = self.model.height // 8
70
+
71
+ # Default configuration instance
72
+ default_config = Config()
src/ddpm.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ class DDPMSampler:
5
+
6
+ def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
7
+ self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
8
+ self.alphas = 1.0 - self.betas
9
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
10
+ self.one = torch.tensor(1.0)
11
+
12
+ self.generator = generator
13
+ self.num_train_timesteps = num_training_steps
14
+ self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
15
+
16
+ def set_inference_timesteps(self, num_inference_steps=50):
17
+ self.num_inference_steps = num_inference_steps
18
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
19
+ inference_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
20
+ self.timesteps = torch.from_numpy(inference_timesteps)
21
+
22
+ def _get_previous_timestep(self, timestep: int) -> int:
23
+ return timestep - self.num_train_timesteps // self.num_inference_steps
24
+
25
+ def _get_variance(self, timestep: int) -> torch.Tensor:
26
+ prev_timestep = self._get_previous_timestep(timestep)
27
+ alpha_prod_t = self.alphas_cumprod[timestep]
28
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
29
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
30
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
31
+ return torch.clamp(variance, min=1e-20)
32
+
33
+ def set_strength(self, strength=1):
34
+ start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
35
+ self.timesteps = self.timesteps[start_step:]
36
+ self.start_step = start_step
37
+
38
+ def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
39
+ prev_timestep = self._get_previous_timestep(timestep)
40
+ alpha_prod_t = self.alphas_cumprod[timestep]
41
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
42
+ beta_prod_t = 1 - alpha_prod_t
43
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
44
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
45
+ current_beta_t = 1 - current_alpha_t
46
+
47
+ pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
48
+ pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
49
+ current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
50
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
51
+
52
+ variance = 0
53
+ if timestep > 0:
54
+ device = model_output.device
55
+ noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
56
+ variance = (self._get_variance(timestep) ** 0.5) * noise
57
+
58
+ return pred_prev_sample + variance
59
+
60
+ def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
61
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
62
+ timesteps = timesteps.to(original_samples.device)
63
+
64
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
65
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
66
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
67
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
68
+
69
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
70
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
71
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
72
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
73
+
74
+ noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
75
+ return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
76
+
src/decoder.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .attention import SelfAttention
5
+
6
+ class VAE_AttentionBlock(nn.Module):
7
+ def __init__(self, channels):
8
+ super().__init__()
9
+ self.groupnorm = nn.GroupNorm(32, channels)
10
+ self.attention = SelfAttention(1, channels)
11
+
12
+ def forward(self, x):
13
+ residue = x
14
+ x = self.groupnorm(x)
15
+ n, c, h, w = x.shape
16
+ x = x.view((n, c, h * w)).transpose(-1, -2)
17
+ x = self.attention(x)
18
+ x = x.transpose(-1, -2).view((n, c, h, w))
19
+ return x + residue
20
+
21
+ class VAE_ResidualBlock(nn.Module):
22
+ def __init__(self, in_channels, out_channels):
23
+ super().__init__()
24
+ self.groupnorm_1 = nn.GroupNorm(32, in_channels)
25
+ self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
26
+ self.groupnorm_2 = nn.GroupNorm(32, out_channels)
27
+ self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
28
+ self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
29
+
30
+ def forward(self, x):
31
+ residue = x
32
+ x = self.groupnorm_1(x)
33
+ x = F.silu(x)
34
+ x = self.conv_1(x)
35
+ x = self.groupnorm_2(x)
36
+ x = F.silu(x)
37
+ x = self.conv_2(x)
38
+ return x + self.residual_layer(residue)
39
+
40
+ class VAE_Decoder(nn.Sequential):
41
+ def __init__(self):
42
+ super().__init__(
43
+ nn.Conv2d(4, 4, kernel_size=1, padding=0),
44
+ nn.Conv2d(4, 512, kernel_size=3, padding=1),
45
+ VAE_ResidualBlock(512, 512),
46
+ VAE_AttentionBlock(512),
47
+ VAE_ResidualBlock(512, 512),
48
+ VAE_ResidualBlock(512, 512),
49
+ VAE_ResidualBlock(512, 512),
50
+ VAE_ResidualBlock(512, 512),
51
+ nn.Upsample(scale_factor=2),
52
+ nn.Conv2d(512, 512, kernel_size=3, padding=1),
53
+ VAE_ResidualBlock(512, 512),
54
+ VAE_ResidualBlock(512, 512),
55
+ VAE_ResidualBlock(512, 512),
56
+ nn.Upsample(scale_factor=2),
57
+ nn.Conv2d(512, 512, kernel_size=3, padding=1),
58
+ VAE_ResidualBlock(512, 256),
59
+ VAE_ResidualBlock(256, 256),
60
+ VAE_ResidualBlock(256, 256),
61
+ nn.Upsample(scale_factor=2),
62
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
63
+ VAE_ResidualBlock(256, 128),
64
+ VAE_ResidualBlock(128, 128),
65
+ VAE_ResidualBlock(128, 128),
66
+ nn.GroupNorm(32, 128),
67
+ nn.SiLU(),
68
+ nn.Conv2d(128, 3, kernel_size=3, padding=1),
69
+ )
70
+
71
+ def forward(self, x):
72
+ x /= 0.18215
73
+ for module in self:
74
+ x = module(x)
75
+ return x
76
+
src/demo.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import model_loader
2
+ import pipeline
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ from transformers import CLIPTokenizer
6
+ import torch
7
+ from config import Config, default_config, DeviceConfig
8
+
9
+ # Device configuration
10
+ ALLOW_CUDA = False
11
+ ALLOW_MPS = False
12
+
13
+ device = "cpu"
14
+ if torch.cuda.is_available() and ALLOW_CUDA:
15
+ device = "cuda"
16
+ elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
17
+ device = "mps"
18
+ print(f"Using device: {device}")
19
+
20
+ # Initialize configuration
21
+ config = Config(
22
+ device=DeviceConfig(device=device),
23
+ seed=42,
24
+ tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
25
+ )
26
+
27
+ # Update diffusion parameters
28
+ config.diffusion.strength = 0.75
29
+ config.diffusion.cfg_scale = 8.0
30
+ config.diffusion.n_inference_steps = 50
31
+
32
+ # Load models with SE blocks enabled
33
+ model_file = "data/v1-5-pruned-emaonly.ckpt"
34
+ config.models = model_loader.load_models(model_file, device, use_se=True)
35
+
36
+ # Generate image
37
+ prompt = "A ultra sharp photorealtici painting of a futuristic cityscape at night with neon lights and flying cars"
38
+ uncond_prompt = ""
39
+
40
+ output_image = pipeline.generate(
41
+ prompt=prompt,
42
+ uncond_prompt=uncond_prompt,
43
+ config=config
44
+ )
45
+
46
+ # Save output
47
+ output_image = Image.fromarray(output_image)
48
+ output_image.save("output.png")
src/diffusion.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .attention import SelfAttention, CrossAttention
5
+
6
+ class TimeEmbedding(nn.Module):
7
+ def __init__(self, n_embd):
8
+ super().__init__()
9
+ self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
10
+ self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
11
+
12
+ def forward(self, x):
13
+ x = F.silu(self.linear_1(x))
14
+ return self.linear_2(x)
15
+
16
+ class SqueezeExcitation(nn.Module):
17
+ def __init__(self, channels, reduction=16):
18
+ super().__init__()
19
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
20
+ self.fc = nn.Sequential(
21
+ nn.Linear(channels, channels // reduction, bias=False),
22
+ nn.ReLU(inplace=True),
23
+ nn.Linear(channels // reduction, channels, bias=False),
24
+ nn.Sigmoid()
25
+ )
26
+
27
+ def forward(self, x):
28
+ b, c, _, _ = x.size()
29
+ y = self.avg_pool(x).view(b, c)
30
+ y = self.fc(y).view(b, c, 1, 1)
31
+ return x * y.expand_as(x)
32
+
33
+ class UNET_ResidualBlock(nn.Module):
34
+ def __init__(self, in_channels, out_channels, n_time=1280, use_se=False):
35
+ super().__init__()
36
+ self.groupnorm_feature = nn.GroupNorm(32, in_channels)
37
+ self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
38
+ self.linear_time = nn.Linear(n_time, out_channels)
39
+ self.groupnorm_merged = nn.GroupNorm(32, out_channels)
40
+ self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
41
+ self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
42
+
43
+ # Add Squeeze-Excitation blocks only if use_se is True
44
+ self.use_se = use_se
45
+ if use_se:
46
+ self.se1 = SqueezeExcitation(out_channels)
47
+ self.se2 = SqueezeExcitation(out_channels)
48
+
49
+ def forward(self, feature, time):
50
+ residue = feature
51
+ feature = F.silu(self.groupnorm_feature(feature))
52
+ feature = self.conv_feature(feature)
53
+ if self.use_se:
54
+ feature = self.se1(feature) # Apply SE after first conv
55
+
56
+ time = self.linear_time(F.silu(time))
57
+ merged = feature + time.unsqueeze(-1).unsqueeze(-1)
58
+ merged = F.silu(self.groupnorm_merged(merged))
59
+ merged = self.conv_merged(merged)
60
+ if self.use_se:
61
+ merged = self.se2(merged) # Apply SE after second conv
62
+
63
+ return merged + self.residual_layer(residue)
64
+
65
+ class UNET_AttentionBlock(nn.Module):
66
+ def __init__(self, n_head: int, n_embd: int, d_context=768):
67
+ super().__init__()
68
+ channels = n_head * n_embd
69
+ self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
70
+ self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
71
+ self.layernorm_1 = nn.LayerNorm(channels)
72
+ self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
73
+ self.layernorm_2 = nn.LayerNorm(channels)
74
+ self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
75
+ self.layernorm_3 = nn.LayerNorm(channels)
76
+ self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
77
+ self.linear_geglu_2 = nn.Linear(4 * channels, channels)
78
+ self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
79
+
80
+ def forward(self, x, context):
81
+ residue_long = x
82
+ x = self.conv_input(self.groupnorm(x))
83
+ n, c, h, w = x.shape
84
+ x = x.view((n, c, h * w)).transpose(-1, -2)
85
+ residue_short = x
86
+ x = self.attention_1(self.layernorm_1(x)) + residue_short
87
+ residue_short = x
88
+ x = self.attention_2(self.layernorm_2(x), context) + residue_short
89
+ residue_short = x
90
+ x, gate = self.linear_geglu_1(self.layernorm_3(x)).chunk(2, dim=-1)
91
+ x = self.linear_geglu_2(x * F.gelu(gate)) + residue_short
92
+ x = x.transpose(-1, -2).view((n, c, h, w))
93
+ return self.conv_output(x) + residue_long
94
+
95
+ class Upsample(nn.Module):
96
+ def __init__(self, channels):
97
+ super().__init__()
98
+ self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
99
+
100
+ def forward(self, x):
101
+ return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
102
+
103
+ class SwitchSequential(nn.Sequential):
104
+ def forward(self, x, context, time):
105
+ for layer in self:
106
+ if isinstance(layer, UNET_AttentionBlock):
107
+ x = layer(x, context)
108
+ elif isinstance(layer, UNET_ResidualBlock):
109
+ x = layer(x, time)
110
+ else:
111
+ x = layer(x)
112
+ return x
113
+
114
+ class UNET(nn.Module):
115
+ def __init__(self, use_se=False):
116
+ super().__init__()
117
+ self.encoders = nn.ModuleList([
118
+ SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
119
+ SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
120
+ SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
121
+ SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
122
+ SwitchSequential(UNET_ResidualBlock(320, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
123
+ SwitchSequential(UNET_ResidualBlock(640, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
124
+ SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
125
+ SwitchSequential(UNET_ResidualBlock(640, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
126
+ SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
127
+ SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
128
+ SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
129
+ SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
130
+ ])
131
+
132
+ self.bottleneck = SwitchSequential(
133
+ UNET_ResidualBlock(1280, 1280, use_se=use_se),
134
+ UNET_AttentionBlock(8, 160),
135
+ UNET_ResidualBlock(1280, 1280, use_se=use_se),
136
+ )
137
+
138
+ self.decoders = nn.ModuleList([
139
+ SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
140
+ SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
141
+ SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), Upsample(1280)),
142
+ SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
143
+ SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
144
+ SwitchSequential(UNET_ResidualBlock(1920, 1280, use_se=use_se), UNET_AttentionBlock(8, 160), Upsample(1280)),
145
+ SwitchSequential(UNET_ResidualBlock(1920, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
146
+ SwitchSequential(UNET_ResidualBlock(1280, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
147
+ SwitchSequential(UNET_ResidualBlock(960, 640, use_se=use_se), UNET_AttentionBlock(8, 80), Upsample(640)),
148
+ SwitchSequential(UNET_ResidualBlock(960, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
149
+ SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
150
+ SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
151
+ ])
152
+
153
+ def forward(self, x, context, time):
154
+ skip_connections = []
155
+ for layers in self.encoders:
156
+ x = layers(x, context, time)
157
+ skip_connections.append(x)
158
+
159
+ x = self.bottleneck(x, context, time)
160
+
161
+ for layers in self.decoders:
162
+ x = torch.cat((x, skip_connections.pop()), dim=1)
163
+ x = layers(x, context, time)
164
+
165
+ return x
166
+
167
+ class UNET_OutputLayer(nn.Module):
168
+ def __init__(self, in_channels, out_channels):
169
+ super().__init__()
170
+ self.groupnorm = nn.GroupNorm(32, in_channels)
171
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
172
+
173
+ def forward(self, x):
174
+ x = F.silu(self.groupnorm(x))
175
+ return self.conv(x)
176
+
177
+ class Diffusion(nn.Module):
178
+ def __init__(self, use_se=False):
179
+ super().__init__()
180
+ self.time_embedding = TimeEmbedding(320)
181
+ self.unet = UNET(use_se=use_se)
182
+ self.final = UNET_OutputLayer(320, 4)
183
+
184
+ def forward(self, latent, context, time):
185
+ time = self.time_embedding(time)
186
+ output = self.unet(latent, context, time)
187
+ return self.final(output)
src/encoder.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .decoder import VAE_AttentionBlock, VAE_ResidualBlock
5
+
6
+ class VAE_Encoder(nn.Sequential):
7
+ def __init__(self):
8
+ super().__init__(
9
+ nn.Conv2d(3, 128, kernel_size=3, padding=1),
10
+ VAE_ResidualBlock(128, 128),
11
+ VAE_ResidualBlock(128, 128),
12
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
13
+ VAE_ResidualBlock(128, 256),
14
+ VAE_ResidualBlock(256, 256),
15
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
16
+ VAE_ResidualBlock(256, 512),
17
+ VAE_ResidualBlock(512, 512),
18
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
19
+ VAE_ResidualBlock(512, 512),
20
+ VAE_ResidualBlock(512, 512),
21
+ VAE_ResidualBlock(512, 512),
22
+ VAE_AttentionBlock(512),
23
+ VAE_ResidualBlock(512, 512),
24
+ nn.GroupNorm(32, 512),
25
+ nn.SiLU(),
26
+ nn.Conv2d(512, 8, kernel_size=3, padding=1),
27
+ nn.Conv2d(8, 8, kernel_size=1, padding=0),
28
+ )
29
+
30
+ def forward(self, x, noise):
31
+ for module in self:
32
+ if getattr(module, 'stride', None) == (2, 2):
33
+ x = F.pad(x, (0, 1, 0, 1))
34
+ x = module(x)
35
+ mean, log_variance = torch.chunk(x, 2, dim=1)
36
+ log_variance = torch.clamp(log_variance, -30, 20)
37
+ variance = log_variance.exp()
38
+ stdev = variance.sqrt()
39
+ x = mean + stdev * noise
40
+ x *= 0.18215
41
+ return x
42
+
src/model_converter.py ADDED
The diff for this file is too large to render. See raw diff
 
src/model_loader.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .clip import CLIP
2
+ from .encoder import VAE_Encoder
3
+ from .decoder import VAE_Decoder
4
+ from .diffusion import Diffusion
5
+
6
+ from . import model_converter
7
+ import torch
8
+
9
+ def load_models(ckpt_path, device, use_se=False):
10
+ state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
11
+
12
+ encoder = VAE_Encoder().to(device)
13
+ encoder.load_state_dict(state_dict['encoder'], strict=True)
14
+
15
+ decoder = VAE_Decoder().to(device)
16
+ decoder.load_state_dict(state_dict['decoder'], strict=True)
17
+
18
+ # Initialize diffusion model with SE blocks disabled for loading pre-trained weights
19
+ diffusion = Diffusion(use_se=False).to(device)
20
+ diffusion.load_state_dict(state_dict['diffusion'], strict=True)
21
+
22
+ # If SE blocks are requested, reinitialize the model with them
23
+ if use_se:
24
+ diffusion = Diffusion(use_se=True).to(device)
25
+ # Copy the weights from the loaded model
26
+ with torch.no_grad():
27
+ for name, param in diffusion.named_parameters():
28
+ if 'se' not in name: # Skip SE block parameters
29
+ if name in state_dict['diffusion']:
30
+ param.copy_(state_dict['diffusion'][name])
31
+
32
+ clip = CLIP().to(device)
33
+ clip.load_state_dict(state_dict['clip'], strict=True)
34
+
35
+ return {
36
+ 'clip': clip,
37
+ 'encoder': encoder,
38
+ 'decoder': decoder,
39
+ 'diffusion': diffusion,
40
+ }
src/pipeline.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from .ddpm import DDPMSampler
6
+ import logging
7
+ from .config import Config, default_config
8
+
9
+ WIDTH = 512
10
+ HEIGHT = 512
11
+ LATENTS_WIDTH = WIDTH // 8
12
+ LATENTS_HEIGHT = HEIGHT // 8
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+ def generate(
17
+ prompt,
18
+ uncond_prompt=None,
19
+ input_image=None,
20
+ config: Config = default_config,
21
+ ):
22
+ with torch.no_grad():
23
+ validate_strength(config.diffusion.strength)
24
+ generator = initialize_generator(config.seed, config.device.device)
25
+ context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
26
+ latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
27
+ images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
28
+ return postprocess_images(images)
29
+
30
+ def validate_strength(strength):
31
+ if not 0 < strength <= 1:
32
+ raise ValueError("Strength must be between 0 and 1")
33
+
34
+ def initialize_generator(seed, device):
35
+ generator = torch.Generator(device=device)
36
+ if seed is None:
37
+ generator.seed()
38
+ else:
39
+ generator.manual_seed(seed)
40
+ return generator
41
+
42
+ def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
43
+ clip.to(device)
44
+ if do_cfg:
45
+ cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
46
+ cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
47
+ cond_context = clip(cond_tokens)
48
+ uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
49
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
50
+ uncond_context = clip(uncond_tokens)
51
+ context = torch.cat([cond_context, uncond_context])
52
+ else:
53
+ tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
54
+ tokens = torch.tensor(tokens, dtype=torch.long, device=device)
55
+ context = clip(tokens)
56
+ return context
57
+
58
+ def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
59
+ if input_image is None:
60
+ # Initialize with random noise
61
+ latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
62
+ else:
63
+ # Initialize with encoded input image
64
+ latents = encode_image(input_image, models, device)
65
+ # Add noise based on strength
66
+ noise = torch.randn_like(latents, generator=generator)
67
+ latents = (1 - strength) * latents + strength * noise
68
+ return latents
69
+
70
+ def preprocess_image(input_image):
71
+ input_image_tensor = input_image.resize((WIDTH, HEIGHT))
72
+ input_image_tensor = np.array(input_image_tensor)
73
+ input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
74
+ input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
75
+ input_image_tensor = input_image_tensor.unsqueeze(0)
76
+ input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
77
+ return input_image_tensor
78
+
79
+ def get_sampler(sampler_name, generator, n_inference_steps):
80
+ if sampler_name == "ddpm":
81
+ sampler = DDPMSampler(generator)
82
+ sampler.set_inference_timesteps(n_inference_steps)
83
+ else:
84
+ raise ValueError(f"Unknown sampler value {sampler_name}.")
85
+ return sampler
86
+
87
+ def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
88
+ diffusion = models["diffusion"]
89
+ diffusion.to(device)
90
+ sampler = get_sampler(sampler_name, generator, n_inference_steps)
91
+ timesteps = tqdm(sampler.timesteps)
92
+ for timestep in timesteps:
93
+ time_embedding = get_time_embedding(timestep).to(device)
94
+ model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
95
+ model_output = diffusion(model_input, context, time_embedding)
96
+ if do_cfg:
97
+ output_cond, output_uncond = model_output.chunk(2)
98
+ model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
99
+ latents = sampler.step(timestep, latents, model_output)
100
+ decoder = models["decoder"]
101
+ decoder.to(device)
102
+ images = decoder(latents)
103
+ return images
104
+
105
+ def postprocess_images(images):
106
+ images = rescale(images, (-1, 1), (0, 255), clamp=True)
107
+ images = images.permute(0, 2, 3, 1)
108
+ images = images.to("cpu", torch.uint8).numpy()
109
+ return images[0]
110
+
111
+ def rescale(x, old_range, new_range, clamp=False):
112
+ old_min, old_max = old_range
113
+ new_min, new_max = new_range
114
+ x -= old_min
115
+ x *= (new_max - new_min) / (old_max - old_min)
116
+ x += new_min
117
+ if clamp:
118
+ x = x.clamp(new_min, new_max)
119
+ return x
120
+
121
+ def get_time_embedding(timestep):
122
+ freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
123
+ x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
124
+ return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)