Upload 26 files
- app.py +189 -0
- requirements.txt +9 -0
- setup.py +17 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/clip.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/ddpm.cpython-312.pyc +0 -0
- src/__pycache__/decoder.cpython-312.pyc +0 -0
- src/__pycache__/diffusion.cpython-312.pyc +0 -0
- src/__pycache__/encoder.cpython-312.pyc +0 -0
- src/__pycache__/model_converter.cpython-312.pyc +3 -0
- src/__pycache__/model_loader.cpython-312.pyc +0 -0
- src/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/attention.py +69 -0
- src/clip.py +54 -0
- src/config.py +72 -0
- src/ddpm.py +76 -0
- src/decoder.py +76 -0
- src/demo.py +48 -0
- src/diffusion.py +187 -0
- src/encoder.py +42 -0
- src/model_converter.py +0 -0
- src/model_loader.py +40 -0
- src/pipeline.py +124 -0
app.py
ADDED
@@ -0,0 +1,189 @@
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
from huggingface_hub import hf_hub_download
from pathlib import Path
import sys

# Add src directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src import model_loader
from src import pipeline
from src.config import Config, DeviceConfig
from transformers import CLIPTokenizer

# Create data directory if it doesn't exist
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Model configuration
MODEL_REPO = "stable-diffusion-v1-5/stable-diffusion-v1-5"
MODEL_FILENAME = "v1-5-pruned-emaonly.ckpt"
model_file = data_dir / MODEL_FILENAME

# Download model if it doesn't exist
if not model_file.exists():
    print(f"Downloading model from {MODEL_REPO}...")
    model_file = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        local_dir=data_dir,
        local_dir_use_symlinks=False
    )
    print("Model downloaded successfully!")

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize configuration
config = Config(
    device=DeviceConfig(device=device),
    tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
)

# Load models with SE blocks enabled
config.models = model_loader.load_models(str(model_file), device, use_se=True)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Update config with user settings
    config.seed = seed
    config.diffusion.cfg_scale = guidance_scale
    config.diffusion.n_inference_steps = num_inference_steps
    config.model.width = width
    config.model.height = height

    # Generate image
    output_image = pipeline.generate(
        prompt=prompt,
        uncond_prompt=negative_prompt,
        config=config
    )

    # Convert numpy array to PIL Image
    image = Image.fromarray(output_image)

    return image, seed

examples = [
    "An ultra-sharp photorealistic painting of a futuristic cityscape at night with neon lights and flying cars",
    "A serene mountain landscape at sunset with snow-capped peaks and a clear lake reflection",
    "A detailed portrait of a cyberpunk character with glowing neon implants and holographic tattoos",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Custom Diffusion Model Text-to-Image Generator")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch>=2.0.0
gradio>=4.0.0
transformers>=4.30.0
numpy>=1.24.0
Pillow>=10.0.0
huggingface_hub>=0.19.0
accelerate>=0.25.0
safetensors>=0.4.0
setuptools>=65.5.1
setup.py
ADDED
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages

setup(
    name="custom-diffusion",
    version="0.1.0",
    packages=find_packages(),
    install_requires=[
        "torch>=2.0.0",
        "gradio>=4.0.0",
        "transformers>=4.30.0",
        "numpy>=1.24.0",
        "Pillow>=10.0.0",
        "huggingface_hub>=0.19.0",
        "accelerate>=0.25.0",
        "safetensors>=0.4.0",
    ],
)
src/__init__.py
ADDED
@@ -0,0 +1 @@
# This file makes the src directory a Python package
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (196 Bytes).
src/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (4.69 kB).
src/__pycache__/clip.cpython-312.pyc
ADDED
Binary file (4.02 kB).
src/__pycache__/config.cpython-312.pyc
ADDED
Binary file (3.4 kB).
src/__pycache__/ddpm.cpython-312.pyc
ADDED
Binary file (6.46 kB).
src/__pycache__/decoder.cpython-312.pyc
ADDED
Binary file (4.93 kB).
src/__pycache__/diffusion.cpython-312.pyc
ADDED
Binary file (14.2 kB).
src/__pycache__/encoder.cpython-312.pyc
ADDED
Binary file (2.56 kB).
src/__pycache__/model_converter.cpython-312.pyc
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc31a7458a7d5afc6251204fd5949d56297f0e0bc97b6b307d2d70b3e2b38d97
size 170127
src/__pycache__/model_loader.cpython-312.pyc
ADDED
Binary file (1.86 kB).
src/__pycache__/pipeline.cpython-312.pyc
ADDED
Binary file (8.11 kB).
src/attention.py
ADDED
@@ -0,0 +1,69 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)

        if causal_mask:
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            weight.masked_fill_(mask, -torch.inf)

        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)
        output = weight @ v
        output = output.transpose(1, 2).reshape(input_shape)
        output = self.out_proj(output)

        return output

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        q = self.q_proj(x)
        k = self.k_proj(y)
        v = self.v_proj(y)

        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        weight = q @ k.transpose(-1, -2)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)
        output = weight @ v
        output = output.transpose(1, 2).contiguous().view(input_shape)
        output = self.out_proj(output)

        return output
src/clip.py
ADDED
@@ -0,0 +1,54 @@
import torch
from torch import nn
import torch.nn.functional as F
from .attention import SelfAttention

class CLIPEmbedding(nn.Module):
    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        x = self.token_embedding(tokens)
        x += self.position_embedding
        return x

class CLIPLayer(nn.Module):
    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)
        self.activation = nn.GELU()

    def forward(self, x):
        residue = x
        x = self.layernorm_1(x)
        x = self.attention(x, causal_mask=True)
        x += residue

        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)
        x = self.activation(x)
        x = self.linear_2(x)
        x += residue

        return x

class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList([CLIPLayer(12, 768) for _ in range(12)])
        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        state = self.embedding(tokens)
        for layer in self.layers:
            state = layer(state)
        output = self.layernorm(state)
        return output
src/config.py
ADDED
@@ -0,0 +1,72 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import torch

@dataclass
class ModelConfig:
    # Image dimensions
    width: int = 512
    height: int = 512
    latents_width: int = 64  # width // 8
    latents_height: int = 64  # height // 8

    # Model architecture parameters
    n_embd: int = 1280
    n_head: int = 8
    d_context: int = 768

    # UNet parameters
    n_time: int = 1280
    n_channels: int = 4
    n_residual_blocks: int = 2

    # Attention parameters
    attention_heads: int = 8
    attention_dim: int = 1280

@dataclass
class DiffusionConfig:
    # Sampling parameters
    n_inference_steps: int = 50
    guidance_scale: float = 7.5
    strength: float = 0.8

    # Sampler configuration
    sampler_name: str = "ddpm"
    beta_start: float = 0.00085
    beta_end: float = 0.0120
    beta_schedule: str = "linear"

    # Conditioning parameters
    do_cfg: bool = True
    cfg_scale: float = 7.5

@dataclass
class DeviceConfig:
    device: Optional[str] = None
    idle_device: Optional[str] = None

    def __post_init__(self):
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.idle_device is None:
            self.idle_device = "cpu"

@dataclass
class Config:
    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    device: DeviceConfig = field(default_factory=DeviceConfig)

    # Additional settings
    seed: Optional[int] = None
    tokenizer: Optional[Any] = None
    models: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Update latent dimensions based on image dimensions
        self.model.latents_width = self.model.width // 8
        self.model.latents_height = self.model.height // 8

# Default configuration instance
default_config = Config()
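Note (illustration only, not part of the commit): Config.__post_init__ recomputes the latent dimensions from the requested image size, so the two never drift apart. A minimal sketch, assuming the package is importable as src:

from src.config import Config, ModelConfig

# Latent dimensions are the image dimensions divided by 8 (the VAE downsampling factor).
cfg = Config(model=ModelConfig(width=768, height=512))
print(cfg.model.latents_width, cfg.model.latents_height)  # 96 64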
src/ddpm.py
ADDED
@@ -0,0 +1,76 @@
import torch
import numpy as np

class DDPMSampler:

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator
        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        inference_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(inference_timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        return timestep - self.num_train_timesteps // self.num_inference_steps

    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_timestep = self._get_previous_timestep(timestep)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        return torch.clamp(variance, min=1e-20)

    def set_strength(self, strength=1):
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        prev_timestep = self._get_previous_timestep(timestep)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        variance = 0
        if timestep > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            variance = (self._get_variance(timestep) ** 0.5) * noise

        return pred_prev_sample + variance

    def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
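Note (illustration only, not part of the commit): with the default 1,000 training steps and 50 inference steps, set_inference_timesteps yields the schedule 980, 960, ..., 0, i.e. every 20th training step in reverse. A quick check, assuming src is importable:

import torch
from src.ddpm import DDPMSampler

sampler = DDPMSampler(torch.Generator().manual_seed(0))
sampler.set_inference_timesteps(50)
print(sampler.timesteps[:5])  # tensor([980, 960, 940, 920, 900])
print(sampler.timesteps[-1])  # tensor(0)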
src/decoder.py
ADDED
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import SelfAttention

class VAE_AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        residue = x
        x = self.groupnorm(x)
        n, c, h, w = x.shape
        x = x.view((n, c, h * w)).transpose(-1, -2)
        x = self.attention(x)
        x = x.transpose(-1, -2).view((n, c, h, w))
        return x + residue

class VAE_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        residue = x
        x = self.groupnorm_1(x)
        x = F.silu(x)
        x = self.conv_1(x)
        x = self.groupnorm_2(x)
        x = F.silu(x)
        x = self.conv_2(x)
        return x + self.residual_layer(residue)

class VAE_Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x /= 0.18215
        for module in self:
            x = module(x)
        return x
src/demo.py
ADDED
@@ -0,0 +1,48 @@
# Run from the repository root (e.g. `python -m src.demo`) so the `src` package imports resolve.
from src import model_loader
from src import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer
import torch
from src.config import Config, default_config, DeviceConfig

# Device configuration
ALLOW_CUDA = False
ALLOW_MPS = False

device = "cpu"
if torch.cuda.is_available() and ALLOW_CUDA:
    device = "cuda"
elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
    device = "mps"
print(f"Using device: {device}")

# Initialize configuration
config = Config(
    device=DeviceConfig(device=device),
    seed=42,
    tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
)

# Update diffusion parameters
config.diffusion.strength = 0.75
config.diffusion.cfg_scale = 8.0
config.diffusion.n_inference_steps = 50

# Load models with SE blocks enabled
model_file = "data/v1-5-pruned-emaonly.ckpt"
config.models = model_loader.load_models(model_file, device, use_se=True)

# Generate image
prompt = "An ultra-sharp photorealistic painting of a futuristic cityscape at night with neon lights and flying cars"
uncond_prompt = ""

output_image = pipeline.generate(
    prompt=prompt,
    uncond_prompt=uncond_prompt,
    config=config
)

# Save output
output_image = Image.fromarray(output_image)
output_image.save("output.png")
src/diffusion.py
ADDED
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        x = F.silu(self.linear_1(x))
        return self.linear_2(x)

class SqueezeExcitation(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280, use_se=False):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

        # Add Squeeze-Excitation blocks only if use_se is True
        self.use_se = use_se
        if use_se:
            self.se1 = SqueezeExcitation(out_channels)
            self.se2 = SqueezeExcitation(out_channels)

    def forward(self, feature, time):
        residue = feature
        feature = F.silu(self.groupnorm_feature(feature))
        feature = self.conv_feature(feature)
        if self.use_se:
            feature = self.se1(feature)  # Apply SE after first conv

        time = self.linear_time(F.silu(time))
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
        merged = F.silu(self.groupnorm_merged(merged))
        merged = self.conv_merged(merged)
        if self.use_se:
            merged = self.se2(merged)  # Apply SE after second conv

        return merged + self.residual_layer(residue)

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd
        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        residue_long = x
        x = self.conv_input(self.groupnorm(x))
        n, c, h, w = x.shape
        x = x.view((n, c, h * w)).transpose(-1, -2)
        residue_short = x
        x = self.attention_1(self.layernorm_1(x)) + residue_short
        residue_short = x
        x = self.attention_2(self.layernorm_2(x), context) + residue_short
        residue_short = x
        x, gate = self.linear_geglu_1(self.layernorm_3(x)).chunk(2, dim=-1)
        x = self.linear_geglu_2(x * F.gelu(gate)) + residue_short
        x = x.transpose(-1, -2).view((n, c, h, w))
        return self.conv_output(x) + residue_long

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))

class SwitchSequential(nn.Sequential):
    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x

class UNET(nn.Module):
    def __init__(self, use_se=False):
        super().__init__()
        self.encoders = nn.ModuleList([
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(320, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
            SwitchSequential(UNET_ResidualBlock(640, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(640, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
        ])

        self.bottleneck = SwitchSequential(
            UNET_ResidualBlock(1280, 1280, use_se=use_se),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280, use_se=use_se),
        )

        self.decoders = nn.ModuleList([
            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), Upsample(1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(1920, 1280, use_se=use_se), UNET_AttentionBlock(8, 160), Upsample(1280)),
            SwitchSequential(UNET_ResidualBlock(1920, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
            SwitchSequential(UNET_ResidualBlock(1280, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
            SwitchSequential(UNET_ResidualBlock(960, 640, use_se=use_se), UNET_AttentionBlock(8, 80), Upsample(640)),
            SwitchSequential(UNET_ResidualBlock(960, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x

class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.silu(self.groupnorm(x))
        return self.conv(x)

class Diffusion(nn.Module):
    def __init__(self, use_se=False):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET(use_se=use_se)
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        time = self.time_embedding(time)
        output = self.unet(latent, context, time)
        return self.final(output)
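Note (illustration only, not part of the commit): SqueezeExcitation squeezes each channel to a scalar by global average pooling, passes the result through a small reduction MLP, and rescales the channels, so it reweights features but never changes the tensor shape. A small sketch, assuming src is importable:

import torch
from src.diffusion import SqueezeExcitation

se = SqueezeExcitation(channels=64, reduction=16)
x = torch.randn(2, 64, 32, 32)
print(se(x).shape)  # torch.Size([2, 64, 32, 32]): same shape, channel-wise reweighting only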
src/encoder.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .decoder import VAE_AttentionBlock, VAE_ResidualBlock

class VAE_Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(128, 256),
            VAE_ResidualBlock(256, 256),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(256, 512),
            VAE_ResidualBlock(512, 512),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, x, noise):
        for module in self:
            if getattr(module, 'stride', None) == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))
            x = module(x)
        mean, log_variance = torch.chunk(x, 2, dim=1)
        log_variance = torch.clamp(log_variance, -30, 20)
        variance = log_variance.exp()
        stdev = variance.sqrt()
        x = mean + stdev * noise
        x *= 0.18215
        return x
src/model_converter.py
ADDED
The diff for this file is too large to render.
src/model_loader.py
ADDED
@@ -0,0 +1,40 @@
from .clip import CLIP
from .encoder import VAE_Encoder
from .decoder import VAE_Decoder
from .diffusion import Diffusion

from . import model_converter
import torch

def load_models(ckpt_path, device, use_se=False):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    # Initialize diffusion model with SE blocks disabled for loading pre-trained weights
    diffusion = Diffusion(use_se=False).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    # If SE blocks are requested, reinitialize the model with them
    if use_se:
        diffusion = Diffusion(use_se=True).to(device)
        # Copy the weights from the loaded model
        with torch.no_grad():
            for name, param in diffusion.named_parameters():
                if 'se' not in name:  # Skip SE block parameters
                    if name in state_dict['diffusion']:
                        param.copy_(state_dict['diffusion'][name])

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
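Note (illustration only, not part of the commit): the SE path above copies matching parameters by name. An equivalent approach, sketched here under the assumption that state_dict['diffusion'] uses the same key names as the non-SE model, is load_state_dict with strict=False, which loads every matching key and reports only the SE parameters as missing; load_diffusion_with_se is a hypothetical helper for the sketch:

import torch
from src.diffusion import Diffusion

def load_diffusion_with_se(diffusion_state_dict, device):
    # Sketch only: strict=False skips keys absent from the checkpoint (the SE blocks)
    # and returns them as `missing`, which can be sanity-checked or logged.
    model = Diffusion(use_se=True).to(device)
    missing, unexpected = model.load_state_dict(diffusion_state_dict, strict=False)
    assert all("se" in name for name in missing), missing
    assert not unexpected, unexpected
    return model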
src/pipeline.py
ADDED
@@ -0,0 +1,124 @@
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from .ddpm import DDPMSampler
import logging
from .config import Config, default_config

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

logging.basicConfig(level=logging.INFO)

def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    config: Config = default_config,
):
    with torch.no_grad():
        validate_strength(config.diffusion.strength)
        generator = initialize_generator(config.seed, config.device.device)
        context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
        latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
        images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
        return postprocess_images(images)

def validate_strength(strength):
    if not 0 < strength <= 1:
        raise ValueError("Strength must be between 0 and 1")

def initialize_generator(seed, device):
    generator = torch.Generator(device=device)
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return generator

def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
    clip.to(device)
    if do_cfg:
        cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = clip(cond_tokens)
        uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = clip(uncond_tokens)
        context = torch.cat([cond_context, uncond_context])
    else:
        tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
        context = clip(tokens)
    return context

def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
    if input_image is None:
        # Initialize with random noise
        latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
    else:
        # Initialize with encoded input image
        latents = encode_image(input_image, models, device)
        # Add noise based on strength (torch.randn_like does not accept a generator, so use torch.randn)
        noise = torch.randn(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
        latents = (1 - strength) * latents + strength * noise
    return latents

def preprocess_image(input_image):
    input_image_tensor = input_image.resize((WIDTH, HEIGHT))
    input_image_tensor = np.array(input_image_tensor)
    input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
    input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
    input_image_tensor = input_image_tensor.unsqueeze(0)
    input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
    return input_image_tensor

def get_sampler(sampler_name, generator, n_inference_steps):
    if sampler_name == "ddpm":
        sampler = DDPMSampler(generator)
        sampler.set_inference_timesteps(n_inference_steps)
    else:
        raise ValueError(f"Unknown sampler value {sampler_name}.")
    return sampler

def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
    diffusion = models["diffusion"]
    diffusion.to(device)
    sampler = get_sampler(sampler_name, generator, n_inference_steps)
    timesteps = tqdm(sampler.timesteps)
    for timestep in timesteps:
        time_embedding = get_time_embedding(timestep).to(device)
        model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
        model_output = diffusion(model_input, context, time_embedding)
        if do_cfg:
            output_cond, output_uncond = model_output.chunk(2)
            model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
        latents = sampler.step(timestep, latents, model_output)
    decoder = models["decoder"]
    decoder.to(device)
    images = decoder(latents)
    return images

def postprocess_images(images):
    images = rescale(images, (-1, 1), (0, 255), clamp=True)
    images = images.permute(0, 2, 3, 1)
    images = images.to("cpu", torch.uint8).numpy()
    return images[0]

def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
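Note (illustration only, not part of the commit): get_time_embedding returns a 1x320 sinusoidal vector (160 cosine plus 160 sine terms), which TimeEmbedding(320) in src/diffusion.py then projects to the 1280-dimensional time signal consumed by the UNet residual blocks. A quick shape check, assuming src is importable:

from src.pipeline import get_time_embedding

emb = get_time_embedding(980)
print(emb.shape)  # torch.Size([1, 320])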