Spaces:
Paused
Paused
Delete src
Browse files- src/__init__.py +0 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/clip.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/ddpm.cpython-312.pyc +0 -0
- src/__pycache__/decoder.cpython-312.pyc +0 -0
- src/__pycache__/diffusion.cpython-312.pyc +0 -0
- src/__pycache__/encoder.cpython-312.pyc +0 -0
- src/__pycache__/model_converter.cpython-312.pyc +0 -3
- src/__pycache__/model_loader.cpython-312.pyc +0 -0
- src/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/attention.py +0 -69
- src/clip.py +0 -54
- src/config.py +0 -72
- src/ddpm.py +0 -76
- src/decoder.py +0 -76
- src/demo.py +0 -48
- src/diffusion.py +0 -187
- src/encoder.py +0 -42
- src/model_converter.py +0 -0
- src/model_loader.py +0 -40
- src/pipeline.py +0 -123
src/__init__.py
DELETED
File without changes
|
src/__pycache__/attention.cpython-312.pyc
DELETED
Binary file (4.68 kB)
|
|
src/__pycache__/clip.cpython-312.pyc
DELETED
Binary file (4 kB)
|
|
src/__pycache__/config.cpython-312.pyc
DELETED
Binary file (3.39 kB)
|
|
src/__pycache__/ddpm.cpython-312.pyc
DELETED
Binary file (6.45 kB)
|
|
src/__pycache__/decoder.cpython-312.pyc
DELETED
Binary file (4.9 kB)
|
|
src/__pycache__/diffusion.cpython-312.pyc
DELETED
Binary file (14.2 kB)
|
|
src/__pycache__/encoder.cpython-312.pyc
DELETED
Binary file (2.53 kB)
|
|
src/__pycache__/model_converter.cpython-312.pyc
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:65ec381ffd1ecb7e843d62f4c98aab8630d90d1273f051e6761d1b9837281628
|
3 |
-
size 170116
|
|
|
|
|
|
|
|
src/__pycache__/model_loader.cpython-312.pyc
DELETED
Binary file (1.84 kB)
|
|
src/__pycache__/pipeline.cpython-312.pyc
DELETED
Binary file (8.03 kB)
|
|
src/attention.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from torch.nn import functional as F
|
4 |
-
import math
|
5 |
-
|
6 |
-
class SelfAttention(nn.Module):
|
7 |
-
def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
|
8 |
-
super().__init__()
|
9 |
-
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
|
10 |
-
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
|
11 |
-
self.n_heads = n_heads
|
12 |
-
self.d_head = d_embed // n_heads
|
13 |
-
|
14 |
-
def forward(self, x, causal_mask=False):
|
15 |
-
input_shape = x.shape
|
16 |
-
batch_size, sequence_length, d_embed = input_shape
|
17 |
-
interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
|
18 |
-
|
19 |
-
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
20 |
-
q = q.view(interim_shape).transpose(1, 2)
|
21 |
-
k = k.view(interim_shape).transpose(1, 2)
|
22 |
-
v = v.view(interim_shape).transpose(1, 2)
|
23 |
-
|
24 |
-
weight = q @ k.transpose(-1, -2)
|
25 |
-
|
26 |
-
if causal_mask:
|
27 |
-
mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
|
28 |
-
weight.masked_fill_(mask, -torch.inf)
|
29 |
-
|
30 |
-
weight /= math.sqrt(self.d_head)
|
31 |
-
weight = F.softmax(weight, dim=-1)
|
32 |
-
output = weight @ v
|
33 |
-
output = output.transpose(1, 2).reshape(input_shape)
|
34 |
-
output = self.out_proj(output)
|
35 |
-
|
36 |
-
return output
|
37 |
-
|
38 |
-
class CrossAttention(nn.Module):
|
39 |
-
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
|
40 |
-
super().__init__()
|
41 |
-
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
|
42 |
-
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
43 |
-
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
44 |
-
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
|
45 |
-
self.n_heads = n_heads
|
46 |
-
self.d_head = d_embed // n_heads
|
47 |
-
|
48 |
-
def forward(self, x, y):
|
49 |
-
input_shape = x.shape
|
50 |
-
batch_size, sequence_length, d_embed = input_shape
|
51 |
-
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
|
52 |
-
|
53 |
-
q = self.q_proj(x)
|
54 |
-
k = self.k_proj(y)
|
55 |
-
v = self.v_proj(y)
|
56 |
-
|
57 |
-
q = q.view(interim_shape).transpose(1, 2)
|
58 |
-
k = k.view(interim_shape).transpose(1, 2)
|
59 |
-
v = v.view(interim_shape).transpose(1, 2)
|
60 |
-
|
61 |
-
weight = q @ k.transpose(-1, -2)
|
62 |
-
weight /= math.sqrt(self.d_head)
|
63 |
-
weight = F.softmax(weight, dim=-1)
|
64 |
-
output = weight @ v
|
65 |
-
output = output.transpose(1, 2).contiguous().view(input_shape)
|
66 |
-
output = self.out_proj(output)
|
67 |
-
|
68 |
-
return output
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/clip.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
from attention import SelfAttention
|
5 |
-
|
6 |
-
class CLIPEmbedding(nn.Module):
|
7 |
-
def __init__(self, n_vocab: int, n_embd: int, n_token: int):
|
8 |
-
super().__init__()
|
9 |
-
self.token_embedding = nn.Embedding(n_vocab, n_embd)
|
10 |
-
self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
|
11 |
-
|
12 |
-
def forward(self, tokens):
|
13 |
-
x = self.token_embedding(tokens)
|
14 |
-
x += self.position_embedding
|
15 |
-
return x
|
16 |
-
|
17 |
-
class CLIPLayer(nn.Module):
|
18 |
-
def __init__(self, n_head: int, n_embd: int):
|
19 |
-
super().__init__()
|
20 |
-
self.layernorm_1 = nn.LayerNorm(n_embd)
|
21 |
-
self.attention = SelfAttention(n_head, n_embd)
|
22 |
-
self.layernorm_2 = nn.LayerNorm(n_embd)
|
23 |
-
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
|
24 |
-
self.linear_2 = nn.Linear(4 * n_embd, n_embd)
|
25 |
-
self.activation = nn.GELU()
|
26 |
-
|
27 |
-
def forward(self, x):
|
28 |
-
residue = x
|
29 |
-
x = self.layernorm_1(x)
|
30 |
-
x = self.attention(x, causal_mask=True)
|
31 |
-
x += residue
|
32 |
-
|
33 |
-
residue = x
|
34 |
-
x = self.layernorm_2(x)
|
35 |
-
x = self.linear_1(x)
|
36 |
-
x = self.activation(x)
|
37 |
-
x = self.linear_2(x)
|
38 |
-
x += residue
|
39 |
-
|
40 |
-
return x
|
41 |
-
|
42 |
-
class CLIP(nn.Module):
|
43 |
-
def __init__(self):
|
44 |
-
super().__init__()
|
45 |
-
self.embedding = CLIPEmbedding(49408, 768, 77)
|
46 |
-
self.layers = nn.ModuleList([CLIPLayer(12, 768) for _ in range(12)])
|
47 |
-
self.layernorm = nn.LayerNorm(768)
|
48 |
-
|
49 |
-
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
|
50 |
-
state = self.embedding(tokens)
|
51 |
-
for layer in self.layers:
|
52 |
-
state = layer(state)
|
53 |
-
output = self.layernorm(state)
|
54 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/config.py
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass, field
|
2 |
-
from typing import Optional, Dict, Any
|
3 |
-
import torch
|
4 |
-
|
5 |
-
@dataclass
|
6 |
-
class ModelConfig:
|
7 |
-
# Image dimensions
|
8 |
-
width: int = 512
|
9 |
-
height: int = 512
|
10 |
-
latents_width: int = 64 # width // 8
|
11 |
-
latents_height: int = 64 # height // 8
|
12 |
-
|
13 |
-
# Model architecture parameters
|
14 |
-
n_embd: int = 1280
|
15 |
-
n_head: int = 8
|
16 |
-
d_context: int = 768
|
17 |
-
|
18 |
-
# UNet parameters
|
19 |
-
n_time: int = 1280
|
20 |
-
n_channels: int = 4
|
21 |
-
n_residual_blocks: int = 2
|
22 |
-
|
23 |
-
# Attention parameters
|
24 |
-
attention_heads: int = 8
|
25 |
-
attention_dim: int = 1280
|
26 |
-
|
27 |
-
@dataclass
|
28 |
-
class DiffusionConfig:
|
29 |
-
# Sampling parameters
|
30 |
-
n_inference_steps: int = 50
|
31 |
-
guidance_scale: float = 7.5
|
32 |
-
strength: float = 0.8
|
33 |
-
|
34 |
-
# Sampler configuration
|
35 |
-
sampler_name: str = "ddpm"
|
36 |
-
beta_start: float = 0.00085
|
37 |
-
beta_end: float = 0.0120
|
38 |
-
beta_schedule: str = "linear"
|
39 |
-
|
40 |
-
# Conditioning parameters
|
41 |
-
do_cfg: bool = True
|
42 |
-
cfg_scale: float = 7.5
|
43 |
-
|
44 |
-
@dataclass
|
45 |
-
class DeviceConfig:
|
46 |
-
device: Optional[str] = None
|
47 |
-
idle_device: Optional[str] = None
|
48 |
-
|
49 |
-
def __post_init__(self):
|
50 |
-
if self.device is None:
|
51 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
52 |
-
if self.idle_device is None:
|
53 |
-
self.idle_device = "cpu"
|
54 |
-
|
55 |
-
@dataclass
|
56 |
-
class Config:
|
57 |
-
model: ModelConfig = field(default_factory=ModelConfig)
|
58 |
-
diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
|
59 |
-
device: DeviceConfig = field(default_factory=DeviceConfig)
|
60 |
-
|
61 |
-
# Additional settings
|
62 |
-
seed: Optional[int] = None
|
63 |
-
tokenizer: Optional[Any] = None
|
64 |
-
models: Dict[str, Any] = field(default_factory=dict)
|
65 |
-
|
66 |
-
def __post_init__(self):
|
67 |
-
# Update latent dimensions based on image dimensions
|
68 |
-
self.model.latents_width = self.model.width // 8
|
69 |
-
self.model.latents_height = self.model.height // 8
|
70 |
-
|
71 |
-
# Default configuration instance
|
72 |
-
default_config = Config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ddpm.py
DELETED
@@ -1,76 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import numpy as np
|
3 |
-
|
4 |
-
class DDPMSampler:
|
5 |
-
|
6 |
-
def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
|
7 |
-
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
|
8 |
-
self.alphas = 1.0 - self.betas
|
9 |
-
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
10 |
-
self.one = torch.tensor(1.0)
|
11 |
-
|
12 |
-
self.generator = generator
|
13 |
-
self.num_train_timesteps = num_training_steps
|
14 |
-
self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
|
15 |
-
|
16 |
-
def set_inference_timesteps(self, num_inference_steps=50):
|
17 |
-
self.num_inference_steps = num_inference_steps
|
18 |
-
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
19 |
-
inference_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
20 |
-
self.timesteps = torch.from_numpy(inference_timesteps)
|
21 |
-
|
22 |
-
def _get_previous_timestep(self, timestep: int) -> int:
|
23 |
-
return timestep - self.num_train_timesteps // self.num_inference_steps
|
24 |
-
|
25 |
-
def _get_variance(self, timestep: int) -> torch.Tensor:
|
26 |
-
prev_timestep = self._get_previous_timestep(timestep)
|
27 |
-
alpha_prod_t = self.alphas_cumprod[timestep]
|
28 |
-
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
|
29 |
-
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
|
30 |
-
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
|
31 |
-
return torch.clamp(variance, min=1e-20)
|
32 |
-
|
33 |
-
def set_strength(self, strength=1):
|
34 |
-
start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
|
35 |
-
self.timesteps = self.timesteps[start_step:]
|
36 |
-
self.start_step = start_step
|
37 |
-
|
38 |
-
def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
|
39 |
-
prev_timestep = self._get_previous_timestep(timestep)
|
40 |
-
alpha_prod_t = self.alphas_cumprod[timestep]
|
41 |
-
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
|
42 |
-
beta_prod_t = 1 - alpha_prod_t
|
43 |
-
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
44 |
-
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
|
45 |
-
current_beta_t = 1 - current_alpha_t
|
46 |
-
|
47 |
-
pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
48 |
-
pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
|
49 |
-
current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
|
50 |
-
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
|
51 |
-
|
52 |
-
variance = 0
|
53 |
-
if timestep > 0:
|
54 |
-
device = model_output.device
|
55 |
-
noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
|
56 |
-
variance = (self._get_variance(timestep) ** 0.5) * noise
|
57 |
-
|
58 |
-
return pred_prev_sample + variance
|
59 |
-
|
60 |
-
def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
|
61 |
-
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
62 |
-
timesteps = timesteps.to(original_samples.device)
|
63 |
-
|
64 |
-
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
65 |
-
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
66 |
-
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
67 |
-
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
68 |
-
|
69 |
-
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
70 |
-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
71 |
-
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
72 |
-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
73 |
-
|
74 |
-
noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
|
75 |
-
return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/decoder.py
DELETED
@@ -1,76 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from torch.nn import functional as F
|
4 |
-
from attention import SelfAttention
|
5 |
-
|
6 |
-
class VAE_AttentionBlock(nn.Module):
|
7 |
-
def __init__(self, channels):
|
8 |
-
super().__init__()
|
9 |
-
self.groupnorm = nn.GroupNorm(32, channels)
|
10 |
-
self.attention = SelfAttention(1, channels)
|
11 |
-
|
12 |
-
def forward(self, x):
|
13 |
-
residue = x
|
14 |
-
x = self.groupnorm(x)
|
15 |
-
n, c, h, w = x.shape
|
16 |
-
x = x.view((n, c, h * w)).transpose(-1, -2)
|
17 |
-
x = self.attention(x)
|
18 |
-
x = x.transpose(-1, -2).view((n, c, h, w))
|
19 |
-
return x + residue
|
20 |
-
|
21 |
-
class VAE_ResidualBlock(nn.Module):
|
22 |
-
def __init__(self, in_channels, out_channels):
|
23 |
-
super().__init__()
|
24 |
-
self.groupnorm_1 = nn.GroupNorm(32, in_channels)
|
25 |
-
self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
26 |
-
self.groupnorm_2 = nn.GroupNorm(32, out_channels)
|
27 |
-
self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
28 |
-
self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
29 |
-
|
30 |
-
def forward(self, x):
|
31 |
-
residue = x
|
32 |
-
x = self.groupnorm_1(x)
|
33 |
-
x = F.silu(x)
|
34 |
-
x = self.conv_1(x)
|
35 |
-
x = self.groupnorm_2(x)
|
36 |
-
x = F.silu(x)
|
37 |
-
x = self.conv_2(x)
|
38 |
-
return x + self.residual_layer(residue)
|
39 |
-
|
40 |
-
class VAE_Decoder(nn.Sequential):
|
41 |
-
def __init__(self):
|
42 |
-
super().__init__(
|
43 |
-
nn.Conv2d(4, 4, kernel_size=1, padding=0),
|
44 |
-
nn.Conv2d(4, 512, kernel_size=3, padding=1),
|
45 |
-
VAE_ResidualBlock(512, 512),
|
46 |
-
VAE_AttentionBlock(512),
|
47 |
-
VAE_ResidualBlock(512, 512),
|
48 |
-
VAE_ResidualBlock(512, 512),
|
49 |
-
VAE_ResidualBlock(512, 512),
|
50 |
-
VAE_ResidualBlock(512, 512),
|
51 |
-
nn.Upsample(scale_factor=2),
|
52 |
-
nn.Conv2d(512, 512, kernel_size=3, padding=1),
|
53 |
-
VAE_ResidualBlock(512, 512),
|
54 |
-
VAE_ResidualBlock(512, 512),
|
55 |
-
VAE_ResidualBlock(512, 512),
|
56 |
-
nn.Upsample(scale_factor=2),
|
57 |
-
nn.Conv2d(512, 512, kernel_size=3, padding=1),
|
58 |
-
VAE_ResidualBlock(512, 256),
|
59 |
-
VAE_ResidualBlock(256, 256),
|
60 |
-
VAE_ResidualBlock(256, 256),
|
61 |
-
nn.Upsample(scale_factor=2),
|
62 |
-
nn.Conv2d(256, 256, kernel_size=3, padding=1),
|
63 |
-
VAE_ResidualBlock(256, 128),
|
64 |
-
VAE_ResidualBlock(128, 128),
|
65 |
-
VAE_ResidualBlock(128, 128),
|
66 |
-
nn.GroupNorm(32, 128),
|
67 |
-
nn.SiLU(),
|
68 |
-
nn.Conv2d(128, 3, kernel_size=3, padding=1),
|
69 |
-
)
|
70 |
-
|
71 |
-
def forward(self, x):
|
72 |
-
x /= 0.18215
|
73 |
-
for module in self:
|
74 |
-
x = module(x)
|
75 |
-
return x
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/demo.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
import model_loader
|
2 |
-
import pipeline
|
3 |
-
from PIL import Image
|
4 |
-
from pathlib import Path
|
5 |
-
from transformers import CLIPTokenizer
|
6 |
-
import torch
|
7 |
-
from config import Config, default_config, DeviceConfig
|
8 |
-
|
9 |
-
# Device configuration
|
10 |
-
ALLOW_CUDA = False
|
11 |
-
ALLOW_MPS = False
|
12 |
-
|
13 |
-
device = "cpu"
|
14 |
-
if torch.cuda.is_available() and ALLOW_CUDA:
|
15 |
-
device = "cuda"
|
16 |
-
elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
|
17 |
-
device = "mps"
|
18 |
-
print(f"Using device: {device}")
|
19 |
-
|
20 |
-
# Initialize configuration
|
21 |
-
config = Config(
|
22 |
-
device=DeviceConfig(device=device),
|
23 |
-
seed=42,
|
24 |
-
tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
25 |
-
)
|
26 |
-
|
27 |
-
# Update diffusion parameters
|
28 |
-
config.diffusion.strength = 0.75
|
29 |
-
config.diffusion.cfg_scale = 8.0
|
30 |
-
config.diffusion.n_inference_steps = 50
|
31 |
-
|
32 |
-
# Load models with SE blocks enabled
|
33 |
-
model_file = "data/v1-5-pruned-emaonly.ckpt"
|
34 |
-
config.models = model_loader.load_models(model_file, device, use_se=True)
|
35 |
-
|
36 |
-
# Generate image
|
37 |
-
prompt = "A ultra sharp photorealtici painting of a futuristic cityscape at night with neon lights and flying cars"
|
38 |
-
uncond_prompt = ""
|
39 |
-
|
40 |
-
output_image = pipeline.generate(
|
41 |
-
prompt=prompt,
|
42 |
-
uncond_prompt=uncond_prompt,
|
43 |
-
config=config
|
44 |
-
)
|
45 |
-
|
46 |
-
# Save output
|
47 |
-
output_image = Image.fromarray(output_image)
|
48 |
-
output_image.save("output.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/diffusion.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from torch.nn import functional as F
|
4 |
-
from attention import SelfAttention, CrossAttention
|
5 |
-
|
6 |
-
class TimeEmbedding(nn.Module):
|
7 |
-
def __init__(self, n_embd):
|
8 |
-
super().__init__()
|
9 |
-
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
|
10 |
-
self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
|
11 |
-
|
12 |
-
def forward(self, x):
|
13 |
-
x = F.silu(self.linear_1(x))
|
14 |
-
return self.linear_2(x)
|
15 |
-
|
16 |
-
class SqueezeExcitation(nn.Module):
|
17 |
-
def __init__(self, channels, reduction=16):
|
18 |
-
super().__init__()
|
19 |
-
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
20 |
-
self.fc = nn.Sequential(
|
21 |
-
nn.Linear(channels, channels // reduction, bias=False),
|
22 |
-
nn.ReLU(inplace=True),
|
23 |
-
nn.Linear(channels // reduction, channels, bias=False),
|
24 |
-
nn.Sigmoid()
|
25 |
-
)
|
26 |
-
|
27 |
-
def forward(self, x):
|
28 |
-
b, c, _, _ = x.size()
|
29 |
-
y = self.avg_pool(x).view(b, c)
|
30 |
-
y = self.fc(y).view(b, c, 1, 1)
|
31 |
-
return x * y.expand_as(x)
|
32 |
-
|
33 |
-
class UNET_ResidualBlock(nn.Module):
|
34 |
-
def __init__(self, in_channels, out_channels, n_time=1280, use_se=False):
|
35 |
-
super().__init__()
|
36 |
-
self.groupnorm_feature = nn.GroupNorm(32, in_channels)
|
37 |
-
self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
38 |
-
self.linear_time = nn.Linear(n_time, out_channels)
|
39 |
-
self.groupnorm_merged = nn.GroupNorm(32, out_channels)
|
40 |
-
self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
41 |
-
self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
42 |
-
|
43 |
-
# Add Squeeze-Excitation blocks only if use_se is True
|
44 |
-
self.use_se = use_se
|
45 |
-
if use_se:
|
46 |
-
self.se1 = SqueezeExcitation(out_channels)
|
47 |
-
self.se2 = SqueezeExcitation(out_channels)
|
48 |
-
|
49 |
-
def forward(self, feature, time):
|
50 |
-
residue = feature
|
51 |
-
feature = F.silu(self.groupnorm_feature(feature))
|
52 |
-
feature = self.conv_feature(feature)
|
53 |
-
if self.use_se:
|
54 |
-
feature = self.se1(feature) # Apply SE after first conv
|
55 |
-
|
56 |
-
time = self.linear_time(F.silu(time))
|
57 |
-
merged = feature + time.unsqueeze(-1).unsqueeze(-1)
|
58 |
-
merged = F.silu(self.groupnorm_merged(merged))
|
59 |
-
merged = self.conv_merged(merged)
|
60 |
-
if self.use_se:
|
61 |
-
merged = self.se2(merged) # Apply SE after second conv
|
62 |
-
|
63 |
-
return merged + self.residual_layer(residue)
|
64 |
-
|
65 |
-
class UNET_AttentionBlock(nn.Module):
|
66 |
-
def __init__(self, n_head: int, n_embd: int, d_context=768):
|
67 |
-
super().__init__()
|
68 |
-
channels = n_head * n_embd
|
69 |
-
self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
|
70 |
-
self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
71 |
-
self.layernorm_1 = nn.LayerNorm(channels)
|
72 |
-
self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
|
73 |
-
self.layernorm_2 = nn.LayerNorm(channels)
|
74 |
-
self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
|
75 |
-
self.layernorm_3 = nn.LayerNorm(channels)
|
76 |
-
self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
|
77 |
-
self.linear_geglu_2 = nn.Linear(4 * channels, channels)
|
78 |
-
self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
79 |
-
|
80 |
-
def forward(self, x, context):
|
81 |
-
residue_long = x
|
82 |
-
x = self.conv_input(self.groupnorm(x))
|
83 |
-
n, c, h, w = x.shape
|
84 |
-
x = x.view((n, c, h * w)).transpose(-1, -2)
|
85 |
-
residue_short = x
|
86 |
-
x = self.attention_1(self.layernorm_1(x)) + residue_short
|
87 |
-
residue_short = x
|
88 |
-
x = self.attention_2(self.layernorm_2(x), context) + residue_short
|
89 |
-
residue_short = x
|
90 |
-
x, gate = self.linear_geglu_1(self.layernorm_3(x)).chunk(2, dim=-1)
|
91 |
-
x = self.linear_geglu_2(x * F.gelu(gate)) + residue_short
|
92 |
-
x = x.transpose(-1, -2).view((n, c, h, w))
|
93 |
-
return self.conv_output(x) + residue_long
|
94 |
-
|
95 |
-
class Upsample(nn.Module):
|
96 |
-
def __init__(self, channels):
|
97 |
-
super().__init__()
|
98 |
-
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
99 |
-
|
100 |
-
def forward(self, x):
|
101 |
-
return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
|
102 |
-
|
103 |
-
class SwitchSequential(nn.Sequential):
|
104 |
-
def forward(self, x, context, time):
|
105 |
-
for layer in self:
|
106 |
-
if isinstance(layer, UNET_AttentionBlock):
|
107 |
-
x = layer(x, context)
|
108 |
-
elif isinstance(layer, UNET_ResidualBlock):
|
109 |
-
x = layer(x, time)
|
110 |
-
else:
|
111 |
-
x = layer(x)
|
112 |
-
return x
|
113 |
-
|
114 |
-
class UNET(nn.Module):
|
115 |
-
def __init__(self, use_se=False):
|
116 |
-
super().__init__()
|
117 |
-
self.encoders = nn.ModuleList([
|
118 |
-
SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
|
119 |
-
SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
|
120 |
-
SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
|
121 |
-
SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
|
122 |
-
SwitchSequential(UNET_ResidualBlock(320, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
|
123 |
-
SwitchSequential(UNET_ResidualBlock(640, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
|
124 |
-
SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
|
125 |
-
SwitchSequential(UNET_ResidualBlock(640, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
|
126 |
-
SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
|
127 |
-
SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
|
128 |
-
SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
|
129 |
-
SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
|
130 |
-
])
|
131 |
-
|
132 |
-
self.bottleneck = SwitchSequential(
|
133 |
-
UNET_ResidualBlock(1280, 1280, use_se=use_se),
|
134 |
-
UNET_AttentionBlock(8, 160),
|
135 |
-
UNET_ResidualBlock(1280, 1280, use_se=use_se),
|
136 |
-
)
|
137 |
-
|
138 |
-
self.decoders = nn.ModuleList([
|
139 |
-
SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
|
140 |
-
SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
|
141 |
-
SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), Upsample(1280)),
|
142 |
-
SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
|
143 |
-
SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
|
144 |
-
SwitchSequential(UNET_ResidualBlock(1920, 1280, use_se=use_se), UNET_AttentionBlock(8, 160), Upsample(1280)),
|
145 |
-
SwitchSequential(UNET_ResidualBlock(1920, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
|
146 |
-
SwitchSequential(UNET_ResidualBlock(1280, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
|
147 |
-
SwitchSequential(UNET_ResidualBlock(960, 640, use_se=use_se), UNET_AttentionBlock(8, 80), Upsample(640)),
|
148 |
-
SwitchSequential(UNET_ResidualBlock(960, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
|
149 |
-
SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
|
150 |
-
SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
|
151 |
-
])
|
152 |
-
|
153 |
-
def forward(self, x, context, time):
|
154 |
-
skip_connections = []
|
155 |
-
for layers in self.encoders:
|
156 |
-
x = layers(x, context, time)
|
157 |
-
skip_connections.append(x)
|
158 |
-
|
159 |
-
x = self.bottleneck(x, context, time)
|
160 |
-
|
161 |
-
for layers in self.decoders:
|
162 |
-
x = torch.cat((x, skip_connections.pop()), dim=1)
|
163 |
-
x = layers(x, context, time)
|
164 |
-
|
165 |
-
return x
|
166 |
-
|
167 |
-
class UNET_OutputLayer(nn.Module):
|
168 |
-
def __init__(self, in_channels, out_channels):
|
169 |
-
super().__init__()
|
170 |
-
self.groupnorm = nn.GroupNorm(32, in_channels)
|
171 |
-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
172 |
-
|
173 |
-
def forward(self, x):
|
174 |
-
x = F.silu(self.groupnorm(x))
|
175 |
-
return self.conv(x)
|
176 |
-
|
177 |
-
class Diffusion(nn.Module):
|
178 |
-
def __init__(self, use_se=False):
|
179 |
-
super().__init__()
|
180 |
-
self.time_embedding = TimeEmbedding(320)
|
181 |
-
self.unet = UNET(use_se=use_se)
|
182 |
-
self.final = UNET_OutputLayer(320, 4)
|
183 |
-
|
184 |
-
def forward(self, latent, context, time):
|
185 |
-
time = self.time_embedding(time)
|
186 |
-
output = self.unet(latent, context, time)
|
187 |
-
return self.final(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/encoder.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from torch.nn import functional as F
|
4 |
-
from decoder import VAE_AttentionBlock, VAE_ResidualBlock
|
5 |
-
|
6 |
-
class VAE_Encoder(nn.Sequential):
|
7 |
-
def __init__(self):
|
8 |
-
super().__init__(
|
9 |
-
nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
10 |
-
VAE_ResidualBlock(128, 128),
|
11 |
-
VAE_ResidualBlock(128, 128),
|
12 |
-
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
|
13 |
-
VAE_ResidualBlock(128, 256),
|
14 |
-
VAE_ResidualBlock(256, 256),
|
15 |
-
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
|
16 |
-
VAE_ResidualBlock(256, 512),
|
17 |
-
VAE_ResidualBlock(512, 512),
|
18 |
-
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
|
19 |
-
VAE_ResidualBlock(512, 512),
|
20 |
-
VAE_ResidualBlock(512, 512),
|
21 |
-
VAE_ResidualBlock(512, 512),
|
22 |
-
VAE_AttentionBlock(512),
|
23 |
-
VAE_ResidualBlock(512, 512),
|
24 |
-
nn.GroupNorm(32, 512),
|
25 |
-
nn.SiLU(),
|
26 |
-
nn.Conv2d(512, 8, kernel_size=3, padding=1),
|
27 |
-
nn.Conv2d(8, 8, kernel_size=1, padding=0),
|
28 |
-
)
|
29 |
-
|
30 |
-
def forward(self, x, noise):
|
31 |
-
for module in self:
|
32 |
-
if getattr(module, 'stride', None) == (2, 2):
|
33 |
-
x = F.pad(x, (0, 1, 0, 1))
|
34 |
-
x = module(x)
|
35 |
-
mean, log_variance = torch.chunk(x, 2, dim=1)
|
36 |
-
log_variance = torch.clamp(log_variance, -30, 20)
|
37 |
-
variance = log_variance.exp()
|
38 |
-
stdev = variance.sqrt()
|
39 |
-
x = mean + stdev * noise
|
40 |
-
x *= 0.18215
|
41 |
-
return x
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/model_converter.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
src/model_loader.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
from clip import CLIP
|
2 |
-
from encoder import VAE_Encoder
|
3 |
-
from decoder import VAE_Decoder
|
4 |
-
from diffusion import Diffusion
|
5 |
-
|
6 |
-
import model_converter
|
7 |
-
import torch
|
8 |
-
|
9 |
-
def load_models(ckpt_path, device, use_se=False):
|
10 |
-
state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
|
11 |
-
|
12 |
-
encoder = VAE_Encoder().to(device)
|
13 |
-
encoder.load_state_dict(state_dict['encoder'], strict=True)
|
14 |
-
|
15 |
-
decoder = VAE_Decoder().to(device)
|
16 |
-
decoder.load_state_dict(state_dict['decoder'], strict=True)
|
17 |
-
|
18 |
-
# Initialize diffusion model with SE blocks disabled for loading pre-trained weights
|
19 |
-
diffusion = Diffusion(use_se=False).to(device)
|
20 |
-
diffusion.load_state_dict(state_dict['diffusion'], strict=True)
|
21 |
-
|
22 |
-
# If SE blocks are requested, reinitialize the model with them
|
23 |
-
if use_se:
|
24 |
-
diffusion = Diffusion(use_se=True).to(device)
|
25 |
-
# Copy the weights from the loaded model
|
26 |
-
with torch.no_grad():
|
27 |
-
for name, param in diffusion.named_parameters():
|
28 |
-
if 'se' not in name: # Skip SE block parameters
|
29 |
-
if name in state_dict['diffusion']:
|
30 |
-
param.copy_(state_dict['diffusion'][name])
|
31 |
-
|
32 |
-
clip = CLIP().to(device)
|
33 |
-
clip.load_state_dict(state_dict['clip'], strict=True)
|
34 |
-
|
35 |
-
return {
|
36 |
-
'clip': clip,
|
37 |
-
'encoder': encoder,
|
38 |
-
'decoder': decoder,
|
39 |
-
'diffusion': diffusion,
|
40 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/pipeline.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import numpy as np
|
3 |
-
from tqdm import tqdm
|
4 |
-
from ddpm import DDPMSampler
|
5 |
-
import logging
|
6 |
-
from config import Config, default_config
|
7 |
-
|
8 |
-
WIDTH = 512
|
9 |
-
HEIGHT = 512
|
10 |
-
LATENTS_WIDTH = WIDTH // 8
|
11 |
-
LATENTS_HEIGHT = HEIGHT // 8
|
12 |
-
|
13 |
-
logging.basicConfig(level=logging.INFO)
|
14 |
-
|
15 |
-
def generate(
|
16 |
-
prompt,
|
17 |
-
uncond_prompt=None,
|
18 |
-
input_image=None,
|
19 |
-
config: Config = default_config,
|
20 |
-
):
|
21 |
-
with torch.no_grad():
|
22 |
-
validate_strength(config.diffusion.strength)
|
23 |
-
generator = initialize_generator(config.seed, config.device.device)
|
24 |
-
context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
|
25 |
-
latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
|
26 |
-
images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
|
27 |
-
return postprocess_images(images)
|
28 |
-
|
29 |
-
def validate_strength(strength):
|
30 |
-
if not 0 < strength <= 1:
|
31 |
-
raise ValueError("Strength must be between 0 and 1")
|
32 |
-
|
33 |
-
def initialize_generator(seed, device):
|
34 |
-
generator = torch.Generator(device=device)
|
35 |
-
if seed is None:
|
36 |
-
generator.seed()
|
37 |
-
else:
|
38 |
-
generator.manual_seed(seed)
|
39 |
-
return generator
|
40 |
-
|
41 |
-
def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
|
42 |
-
clip.to(device)
|
43 |
-
if do_cfg:
|
44 |
-
cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
|
45 |
-
cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
|
46 |
-
cond_context = clip(cond_tokens)
|
47 |
-
uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
|
48 |
-
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
|
49 |
-
uncond_context = clip(uncond_tokens)
|
50 |
-
context = torch.cat([cond_context, uncond_context])
|
51 |
-
else:
|
52 |
-
tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
|
53 |
-
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
|
54 |
-
context = clip(tokens)
|
55 |
-
return context
|
56 |
-
|
57 |
-
def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
|
58 |
-
if input_image is None:
|
59 |
-
# Initialize with random noise
|
60 |
-
latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
|
61 |
-
else:
|
62 |
-
# Initialize with encoded input image
|
63 |
-
latents = encode_image(input_image, models, device)
|
64 |
-
# Add noise based on strength
|
65 |
-
noise = torch.randn_like(latents, generator=generator)
|
66 |
-
latents = (1 - strength) * latents + strength * noise
|
67 |
-
return latents
|
68 |
-
|
69 |
-
def preprocess_image(input_image):
|
70 |
-
input_image_tensor = input_image.resize((WIDTH, HEIGHT))
|
71 |
-
input_image_tensor = np.array(input_image_tensor)
|
72 |
-
input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
|
73 |
-
input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
|
74 |
-
input_image_tensor = input_image_tensor.unsqueeze(0)
|
75 |
-
input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
|
76 |
-
return input_image_tensor
|
77 |
-
|
78 |
-
def get_sampler(sampler_name, generator, n_inference_steps):
|
79 |
-
if sampler_name == "ddpm":
|
80 |
-
sampler = DDPMSampler(generator)
|
81 |
-
sampler.set_inference_timesteps(n_inference_steps)
|
82 |
-
else:
|
83 |
-
raise ValueError(f"Unknown sampler value {sampler_name}.")
|
84 |
-
return sampler
|
85 |
-
|
86 |
-
def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
|
87 |
-
diffusion = models["diffusion"]
|
88 |
-
diffusion.to(device)
|
89 |
-
sampler = get_sampler(sampler_name, generator, n_inference_steps)
|
90 |
-
timesteps = tqdm(sampler.timesteps)
|
91 |
-
for timestep in timesteps:
|
92 |
-
time_embedding = get_time_embedding(timestep).to(device)
|
93 |
-
model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
|
94 |
-
model_output = diffusion(model_input, context, time_embedding)
|
95 |
-
if do_cfg:
|
96 |
-
output_cond, output_uncond = model_output.chunk(2)
|
97 |
-
model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
|
98 |
-
latents = sampler.step(timestep, latents, model_output)
|
99 |
-
decoder = models["decoder"]
|
100 |
-
decoder.to(device)
|
101 |
-
images = decoder(latents)
|
102 |
-
return images
|
103 |
-
|
104 |
-
def postprocess_images(images):
|
105 |
-
images = rescale(images, (-1, 1), (0, 255), clamp=True)
|
106 |
-
images = images.permute(0, 2, 3, 1)
|
107 |
-
images = images.to("cpu", torch.uint8).numpy()
|
108 |
-
return images[0]
|
109 |
-
|
110 |
-
def rescale(x, old_range, new_range, clamp=False):
|
111 |
-
old_min, old_max = old_range
|
112 |
-
new_min, new_max = new_range
|
113 |
-
x -= old_min
|
114 |
-
x *= (new_max - new_min) / (old_max - old_min)
|
115 |
-
x += new_min
|
116 |
-
if clamp:
|
117 |
-
x = x.clamp(new_min, new_max)
|
118 |
-
return x
|
119 |
-
|
120 |
-
def get_time_embedding(timestep):
|
121 |
-
freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
|
122 |
-
x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
|
123 |
-
return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|