torinriley committed on
Commit 9937b2d · verified · 1 Parent(s): 2f6f73f

Delete src

src/__init__.py DELETED
File without changes
src/__pycache__/attention.cpython-312.pyc DELETED
Binary file (4.68 kB)
 
src/__pycache__/clip.cpython-312.pyc DELETED
Binary file (4 kB)
 
src/__pycache__/config.cpython-312.pyc DELETED
Binary file (3.39 kB)
 
src/__pycache__/ddpm.cpython-312.pyc DELETED
Binary file (6.45 kB)
 
src/__pycache__/decoder.cpython-312.pyc DELETED
Binary file (4.9 kB)
 
src/__pycache__/diffusion.cpython-312.pyc DELETED
Binary file (14.2 kB)
 
src/__pycache__/encoder.cpython-312.pyc DELETED
Binary file (2.53 kB)
 
src/__pycache__/model_converter.cpython-312.pyc DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:65ec381ffd1ecb7e843d62f4c98aab8630d90d1273f051e6761d1b9837281628
- size 170116
 
src/__pycache__/model_loader.cpython-312.pyc DELETED
Binary file (1.84 kB)
 
src/__pycache__/pipeline.cpython-312.pyc DELETED
Binary file (8.03 kB)
 
src/attention.py DELETED
@@ -1,69 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- import math
-
- class SelfAttention(nn.Module):
-     def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
-         super().__init__()
-         self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
-         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
-         self.n_heads = n_heads
-         self.d_head = d_embed // n_heads
-
-     def forward(self, x, causal_mask=False):
-         input_shape = x.shape
-         batch_size, sequence_length, d_embed = input_shape
-         interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
-
-         q, k, v = self.in_proj(x).chunk(3, dim=-1)
-         q = q.view(interim_shape).transpose(1, 2)
-         k = k.view(interim_shape).transpose(1, 2)
-         v = v.view(interim_shape).transpose(1, 2)
-
-         weight = q @ k.transpose(-1, -2)
-
-         if causal_mask:
-             mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
-             weight.masked_fill_(mask, -torch.inf)
-
-         weight /= math.sqrt(self.d_head)
-         weight = F.softmax(weight, dim=-1)
-         output = weight @ v
-         output = output.transpose(1, 2).reshape(input_shape)
-         output = self.out_proj(output)
-
-         return output
-
- class CrossAttention(nn.Module):
-     def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
-         super().__init__()
-         self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
-         self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
-         self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
-         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
-         self.n_heads = n_heads
-         self.d_head = d_embed // n_heads
-
-     def forward(self, x, y):
-         input_shape = x.shape
-         batch_size, sequence_length, d_embed = input_shape
-         interim_shape = (batch_size, -1, self.n_heads, self.d_head)
-
-         q = self.q_proj(x)
-         k = self.k_proj(y)
-         v = self.v_proj(y)
-
-         q = q.view(interim_shape).transpose(1, 2)
-         k = k.view(interim_shape).transpose(1, 2)
-         v = v.view(interim_shape).transpose(1, 2)
-
-         weight = q @ k.transpose(-1, -2)
-         weight /= math.sqrt(self.d_head)
-         weight = F.softmax(weight, dim=-1)
-         output = weight @ v
-         output = output.transpose(1, 2).contiguous().view(input_shape)
-         output = self.out_proj(output)
-
-         return output
-
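For reference, a minimal shape-check sketch of how the deleted attention modules were exercised elsewhere in the repo (not part of the commit; it assumes the deleted src/attention.py is still on the import path, and the concrete sizes are illustrative):

# Illustrative only: d_embed must be divisible by n_heads.
import torch
from attention import SelfAttention, CrossAttention

x = torch.randn(1, 77, 768)      # (batch, sequence, d_embed)
ctx = torch.randn(1, 77, 768)    # cross-attention context, (batch, seq, d_cross)

self_attn = SelfAttention(n_heads=12, d_embed=768)
cross_attn = CrossAttention(n_heads=12, d_embed=768, d_cross=768)

print(self_attn(x, causal_mask=True).shape)  # torch.Size([1, 77, 768])
print(cross_attn(x, ctx).shape)              # torch.Size([1, 77, 768])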
 
src/clip.py DELETED
@@ -1,54 +0,0 @@
- import torch
- from torch import nn
- import torch.nn.functional as F
- from attention import SelfAttention
-
- class CLIPEmbedding(nn.Module):
-     def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-         super().__init__()
-         self.token_embedding = nn.Embedding(n_vocab, n_embd)
-         self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-     def forward(self, tokens):
-         x = self.token_embedding(tokens)
-         x += self.position_embedding
-         return x
-
- class CLIPLayer(nn.Module):
-     def __init__(self, n_head: int, n_embd: int):
-         super().__init__()
-         self.layernorm_1 = nn.LayerNorm(n_embd)
-         self.attention = SelfAttention(n_head, n_embd)
-         self.layernorm_2 = nn.LayerNorm(n_embd)
-         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-         self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-         self.activation = nn.GELU()
-
-     def forward(self, x):
-         residue = x
-         x = self.layernorm_1(x)
-         x = self.attention(x, causal_mask=True)
-         x += residue
-
-         residue = x
-         x = self.layernorm_2(x)
-         x = self.linear_1(x)
-         x = self.activation(x)
-         x = self.linear_2(x)
-         x += residue
-
-         return x
-
- class CLIP(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.embedding = CLIPEmbedding(49408, 768, 77)
-         self.layers = nn.ModuleList([CLIPLayer(12, 768) for _ in range(12)])
-         self.layernorm = nn.LayerNorm(768)
-
-     def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-         state = self.embedding(tokens)
-         for layer in self.layers:
-             state = layer(state)
-         output = self.layernorm(state)
-         return output
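The text encoder's contract, implied by the constants above (vocab 49408, 77 tokens, 768-d embeddings): token IDs in, a (batch, 77, 768) context tensor out. A small sketch, again assuming the deleted modules are importable; the random token IDs are only for a shape check:

# Illustrative only: assumes src/clip.py and src/attention.py are on the import path.
import torch
from clip import CLIP

clip = CLIP()
tokens = torch.randint(0, 49408, (1, 77), dtype=torch.long)  # already padded/truncated to 77
context = clip(tokens)
print(context.shape)  # torch.Size([1, 77, 768])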
 
src/config.py DELETED
@@ -1,72 +0,0 @@
- from dataclasses import dataclass, field
- from typing import Optional, Dict, Any
- import torch
-
- @dataclass
- class ModelConfig:
-     # Image dimensions
-     width: int = 512
-     height: int = 512
-     latents_width: int = 64  # width // 8
-     latents_height: int = 64  # height // 8
-
-     # Model architecture parameters
-     n_embd: int = 1280
-     n_head: int = 8
-     d_context: int = 768
-
-     # UNet parameters
-     n_time: int = 1280
-     n_channels: int = 4
-     n_residual_blocks: int = 2
-
-     # Attention parameters
-     attention_heads: int = 8
-     attention_dim: int = 1280
-
- @dataclass
- class DiffusionConfig:
-     # Sampling parameters
-     n_inference_steps: int = 50
-     guidance_scale: float = 7.5
-     strength: float = 0.8
-
-     # Sampler configuration
-     sampler_name: str = "ddpm"
-     beta_start: float = 0.00085
-     beta_end: float = 0.0120
-     beta_schedule: str = "linear"
-
-     # Conditioning parameters
-     do_cfg: bool = True
-     cfg_scale: float = 7.5
-
- @dataclass
- class DeviceConfig:
-     device: Optional[str] = None
-     idle_device: Optional[str] = None
-
-     def __post_init__(self):
-         if self.device is None:
-             self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         if self.idle_device is None:
-             self.idle_device = "cpu"
-
- @dataclass
- class Config:
-     model: ModelConfig = field(default_factory=ModelConfig)
-     diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
-     device: DeviceConfig = field(default_factory=DeviceConfig)
-
-     # Additional settings
-     seed: Optional[int] = None
-     tokenizer: Optional[Any] = None
-     models: Dict[str, Any] = field(default_factory=dict)
-
-     def __post_init__(self):
-         # Update latent dimensions based on image dimensions
-         self.model.latents_width = self.model.width // 8
-         self.model.latents_height = self.model.height // 8
-
- # Default configuration instance
- default_config = Config()
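A brief sketch of how the deleted dataclass config was meant to be composed (not part of the commit; assumes src/config.py is importable). Only width/height need to be set, since __post_init__ recomputes the latent dimensions:

# Illustrative only.
from config import Config, ModelConfig, DiffusionConfig

cfg = Config(
    model=ModelConfig(width=768, height=512),
    diffusion=DiffusionConfig(n_inference_steps=30, cfg_scale=7.5),
    seed=123,
)
print(cfg.model.latents_width, cfg.model.latents_height)  # 96 64
print(cfg.device.device)  # "cuda" if available, otherwise "cpu"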
 
src/ddpm.py DELETED
@@ -1,76 +0,0 @@
- import torch
- import numpy as np
-
- class DDPMSampler:
-
-     def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
-         self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-         self.one = torch.tensor(1.0)
-
-         self.generator = generator
-         self.num_train_timesteps = num_training_steps
-         self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
-
-     def set_inference_timesteps(self, num_inference_steps=50):
-         self.num_inference_steps = num_inference_steps
-         step_ratio = self.num_train_timesteps // self.num_inference_steps
-         inference_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
-         self.timesteps = torch.from_numpy(inference_timesteps)
-
-     def _get_previous_timestep(self, timestep: int) -> int:
-         return timestep - self.num_train_timesteps // self.num_inference_steps
-
-     def _get_variance(self, timestep: int) -> torch.Tensor:
-         prev_timestep = self._get_previous_timestep(timestep)
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
-         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
-         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
-         return torch.clamp(variance, min=1e-20)
-
-     def set_strength(self, strength=1):
-         start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
-         self.timesteps = self.timesteps[start_step:]
-         self.start_step = start_step
-
-     def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
-         prev_timestep = self._get_previous_timestep(timestep)
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
-         beta_prod_t = 1 - alpha_prod_t
-         beta_prod_t_prev = 1 - alpha_prod_t_prev
-         current_alpha_t = alpha_prod_t / alpha_prod_t_prev
-         current_beta_t = 1 - current_alpha_t
-
-         pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
-         pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
-         current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
-         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
-
-         variance = 0
-         if timestep > 0:
-             device = model_output.device
-             noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
-             variance = (self._get_variance(timestep) ** 0.5) * noise
-
-         return pred_prev_sample + variance
-
-     def add_noise(self, original_samples: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
-         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-         timesteps = timesteps.to(original_samples.device)
-
-         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
-         return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-
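add_noise above implements the standard DDPM forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and step inverts one denoising step. A sketch of the two entry points (not part of the commit; assumes src/ddpm.py is importable, and the tensors are random stand-ins):

# Illustrative only.
import torch
from ddpm import DDPMSampler

sampler = DDPMSampler(torch.Generator().manual_seed(0))
sampler.set_inference_timesteps(50)   # 1000 training steps subsampled to [980, 960, ..., 0]

x0 = torch.randn(1, 4, 64, 64)                     # clean latents
xt = sampler.add_noise(x0, torch.tensor([999]))    # heavily noised sample, same shape

eps_pred = torch.randn_like(xt)                    # stand-in for the UNet's noise prediction
x_prev = sampler.step(int(sampler.timesteps[0]), xt, eps_pred)
print(x_prev.shape)                                # torch.Size([1, 4, 64, 64])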
 
src/decoder.py DELETED
@@ -1,76 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from attention import SelfAttention
-
- class VAE_AttentionBlock(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.groupnorm = nn.GroupNorm(32, channels)
-         self.attention = SelfAttention(1, channels)
-
-     def forward(self, x):
-         residue = x
-         x = self.groupnorm(x)
-         n, c, h, w = x.shape
-         x = x.view((n, c, h * w)).transpose(-1, -2)
-         x = self.attention(x)
-         x = x.transpose(-1, -2).view((n, c, h, w))
-         return x + residue
-
- class VAE_ResidualBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.groupnorm_1 = nn.GroupNorm(32, in_channels)
-         self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-         self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-         self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-         self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-     def forward(self, x):
-         residue = x
-         x = self.groupnorm_1(x)
-         x = F.silu(x)
-         x = self.conv_1(x)
-         x = self.groupnorm_2(x)
-         x = F.silu(x)
-         x = self.conv_2(x)
-         return x + self.residual_layer(residue)
-
- class VAE_Decoder(nn.Sequential):
-     def __init__(self):
-         super().__init__(
-             nn.Conv2d(4, 4, kernel_size=1, padding=0),
-             nn.Conv2d(4, 512, kernel_size=3, padding=1),
-             VAE_ResidualBlock(512, 512),
-             VAE_AttentionBlock(512),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             nn.Upsample(scale_factor=2),
-             nn.Conv2d(512, 512, kernel_size=3, padding=1),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             nn.Upsample(scale_factor=2),
-             nn.Conv2d(512, 512, kernel_size=3, padding=1),
-             VAE_ResidualBlock(512, 256),
-             VAE_ResidualBlock(256, 256),
-             VAE_ResidualBlock(256, 256),
-             nn.Upsample(scale_factor=2),
-             nn.Conv2d(256, 256, kernel_size=3, padding=1),
-             VAE_ResidualBlock(256, 128),
-             VAE_ResidualBlock(128, 128),
-             VAE_ResidualBlock(128, 128),
-             nn.GroupNorm(32, 128),
-             nn.SiLU(),
-             nn.Conv2d(128, 3, kernel_size=3, padding=1),
-         )
-
-     def forward(self, x):
-         x /= 0.18215
-         for module in self:
-             x = module(x)
-         return x
-
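The decoder undoes the 0.18215 latent scaling, then maps the 4-channel latent grid back to a 3-channel image at 8x the spatial resolution via the three Upsample stages. A shape sketch (not part of the commit; assumes the deleted src/decoder.py and src/attention.py are importable):

# Illustrative only: randomly initialised weights, shape check only.
import torch
from decoder import VAE_Decoder

decoder = VAE_Decoder()
latents = torch.randn(1, 4, 64, 64)
with torch.no_grad():
    image = decoder(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])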
 
src/demo.py DELETED
@@ -1,48 +0,0 @@
- import model_loader
- import pipeline
- from PIL import Image
- from pathlib import Path
- from transformers import CLIPTokenizer
- import torch
- from config import Config, default_config, DeviceConfig
-
- # Device configuration
- ALLOW_CUDA = False
- ALLOW_MPS = False
-
- device = "cpu"
- if torch.cuda.is_available() and ALLOW_CUDA:
-     device = "cuda"
- elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
-     device = "mps"
- print(f"Using device: {device}")
-
- # Initialize configuration
- config = Config(
-     device=DeviceConfig(device=device),
-     seed=42,
-     tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
- )
-
- # Update diffusion parameters
- config.diffusion.strength = 0.75
- config.diffusion.cfg_scale = 8.0
- config.diffusion.n_inference_steps = 50
-
- # Load models with SE blocks enabled
- model_file = "data/v1-5-pruned-emaonly.ckpt"
- config.models = model_loader.load_models(model_file, device, use_se=True)
-
- # Generate image
- prompt = "A ultra sharp photorealtici painting of a futuristic cityscape at night with neon lights and flying cars"
- uncond_prompt = ""
-
- output_image = pipeline.generate(
-     prompt=prompt,
-     uncond_prompt=uncond_prompt,
-     config=config
- )
-
- # Save output
- output_image = Image.fromarray(output_image)
- output_image.save("output.png")
 
src/diffusion.py DELETED
@@ -1,187 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from attention import SelfAttention, CrossAttention
-
- class TimeEmbedding(nn.Module):
-     def __init__(self, n_embd):
-         super().__init__()
-         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-         self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
-
-     def forward(self, x):
-         x = F.silu(self.linear_1(x))
-         return self.linear_2(x)
-
- class SqueezeExcitation(nn.Module):
-     def __init__(self, channels, reduction=16):
-         super().__init__()
-         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-         self.fc = nn.Sequential(
-             nn.Linear(channels, channels // reduction, bias=False),
-             nn.ReLU(inplace=True),
-             nn.Linear(channels // reduction, channels, bias=False),
-             nn.Sigmoid()
-         )
-
-     def forward(self, x):
-         b, c, _, _ = x.size()
-         y = self.avg_pool(x).view(b, c)
-         y = self.fc(y).view(b, c, 1, 1)
-         return x * y.expand_as(x)
-
- class UNET_ResidualBlock(nn.Module):
-     def __init__(self, in_channels, out_channels, n_time=1280, use_se=False):
-         super().__init__()
-         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
-         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-         self.linear_time = nn.Linear(n_time, out_channels)
-         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
-         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
-         self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
-
-         # Add Squeeze-Excitation blocks only if use_se is True
-         self.use_se = use_se
-         if use_se:
-             self.se1 = SqueezeExcitation(out_channels)
-             self.se2 = SqueezeExcitation(out_channels)
-
-     def forward(self, feature, time):
-         residue = feature
-         feature = F.silu(self.groupnorm_feature(feature))
-         feature = self.conv_feature(feature)
-         if self.use_se:
-             feature = self.se1(feature)  # Apply SE after first conv
-
-         time = self.linear_time(F.silu(time))
-         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
-         merged = F.silu(self.groupnorm_merged(merged))
-         merged = self.conv_merged(merged)
-         if self.use_se:
-             merged = self.se2(merged)  # Apply SE after second conv
-
-         return merged + self.residual_layer(residue)
-
- class UNET_AttentionBlock(nn.Module):
-     def __init__(self, n_head: int, n_embd: int, d_context=768):
-         super().__init__()
-         channels = n_head * n_embd
-         self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
-         self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-         self.layernorm_1 = nn.LayerNorm(channels)
-         self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
-         self.layernorm_2 = nn.LayerNorm(channels)
-         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
-         self.layernorm_3 = nn.LayerNorm(channels)
-         self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
-         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
-         self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-
-     def forward(self, x, context):
-         residue_long = x
-         x = self.conv_input(self.groupnorm(x))
-         n, c, h, w = x.shape
-         x = x.view((n, c, h * w)).transpose(-1, -2)
-         residue_short = x
-         x = self.attention_1(self.layernorm_1(x)) + residue_short
-         residue_short = x
-         x = self.attention_2(self.layernorm_2(x), context) + residue_short
-         residue_short = x
-         x, gate = self.linear_geglu_1(self.layernorm_3(x)).chunk(2, dim=-1)
-         x = self.linear_geglu_2(x * F.gelu(gate)) + residue_short
-         x = x.transpose(-1, -2).view((n, c, h, w))
-         return self.conv_output(x) + residue_long
-
- class Upsample(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
-
- class SwitchSequential(nn.Sequential):
-     def forward(self, x, context, time):
-         for layer in self:
-             if isinstance(layer, UNET_AttentionBlock):
-                 x = layer(x, context)
-             elif isinstance(layer, UNET_ResidualBlock):
-                 x = layer(x, time)
-             else:
-                 x = layer(x)
-         return x
-
- class UNET(nn.Module):
-     def __init__(self, use_se=False):
-         super().__init__()
-         self.encoders = nn.ModuleList([
-             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(UNET_ResidualBlock(320, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(320, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(UNET_ResidualBlock(640, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(640, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
-             SwitchSequential(UNET_ResidualBlock(1280, 1280, use_se=use_se)),
-         ])
-
-         self.bottleneck = SwitchSequential(
-             UNET_ResidualBlock(1280, 1280, use_se=use_se),
-             UNET_AttentionBlock(8, 160),
-             UNET_ResidualBlock(1280, 1280, use_se=use_se),
-         )
-
-         self.decoders = nn.ModuleList([
-             SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), Upsample(1280)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(UNET_ResidualBlock(2560, 1280, use_se=use_se), UNET_AttentionBlock(8, 160)),
-             SwitchSequential(UNET_ResidualBlock(1920, 1280, use_se=use_se), UNET_AttentionBlock(8, 160), Upsample(1280)),
-             SwitchSequential(UNET_ResidualBlock(1920, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(UNET_ResidualBlock(1280, 640, use_se=use_se), UNET_AttentionBlock(8, 80)),
-             SwitchSequential(UNET_ResidualBlock(960, 640, use_se=use_se), UNET_AttentionBlock(8, 80), Upsample(640)),
-             SwitchSequential(UNET_ResidualBlock(960, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-             SwitchSequential(UNET_ResidualBlock(640, 320, use_se=use_se), UNET_AttentionBlock(8, 40)),
-         ])
-
-     def forward(self, x, context, time):
-         skip_connections = []
-         for layers in self.encoders:
-             x = layers(x, context, time)
-             skip_connections.append(x)
-
-         x = self.bottleneck(x, context, time)
-
-         for layers in self.decoders:
-             x = torch.cat((x, skip_connections.pop()), dim=1)
-             x = layers(x, context, time)
-
-         return x
-
- class UNET_OutputLayer(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.groupnorm = nn.GroupNorm(32, in_channels)
-         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
-
-     def forward(self, x):
-         x = F.silu(self.groupnorm(x))
-         return self.conv(x)
-
- class Diffusion(nn.Module):
-     def __init__(self, use_se=False):
-         super().__init__()
-         self.time_embedding = TimeEmbedding(320)
-         self.unet = UNET(use_se=use_se)
-         self.final = UNET_OutputLayer(320, 4)
-
-     def forward(self, latent, context, time):
-         time = self.time_embedding(time)
-         output = self.unet(latent, context, time)
-         return self.final(output)
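A single denoising pass through the deleted Diffusion wrapper takes a (1, 4, 64, 64) latent, a (1, 77, 768) CLIP context, and a 320-dimensional time embedding (produced by get_time_embedding in pipeline.py below), and returns a noise prediction of the latent's shape. Sketch under the same import assumption, with random stand-in tensors:

# Illustrative only: randomly initialised weights; the full UNet is large (hundreds of millions of parameters).
import torch
from diffusion import Diffusion

model = Diffusion(use_se=False)
latent = torch.randn(1, 4, 64, 64)
context = torch.randn(1, 77, 768)
time_embedding = torch.randn(1, 320)   # stand-in for pipeline.get_time_embedding(timestep)
with torch.no_grad():
    noise_pred = model(latent, context, time_embedding)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])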
 
src/encoder.py DELETED
@@ -1,42 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
- from decoder import VAE_AttentionBlock, VAE_ResidualBlock
-
- class VAE_Encoder(nn.Sequential):
-     def __init__(self):
-         super().__init__(
-             nn.Conv2d(3, 128, kernel_size=3, padding=1),
-             VAE_ResidualBlock(128, 128),
-             VAE_ResidualBlock(128, 128),
-             nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
-             VAE_ResidualBlock(128, 256),
-             VAE_ResidualBlock(256, 256),
-             nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
-             VAE_ResidualBlock(256, 512),
-             VAE_ResidualBlock(512, 512),
-             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             VAE_ResidualBlock(512, 512),
-             VAE_AttentionBlock(512),
-             VAE_ResidualBlock(512, 512),
-             nn.GroupNorm(32, 512),
-             nn.SiLU(),
-             nn.Conv2d(512, 8, kernel_size=3, padding=1),
-             nn.Conv2d(8, 8, kernel_size=1, padding=0),
-         )
-
-     def forward(self, x, noise):
-         for module in self:
-             if getattr(module, 'stride', None) == (2, 2):
-                 x = F.pad(x, (0, 1, 0, 1))
-             x = module(x)
-         mean, log_variance = torch.chunk(x, 2, dim=1)
-         log_variance = torch.clamp(log_variance, -30, 20)
-         variance = log_variance.exp()
-         stdev = variance.sqrt()
-         x = mean + stdev * noise
-         x *= 0.18215
-         return x
-
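The encoder is the mirror image of the decoder: a (1, 3, 512, 512) image plus a (1, 4, 64, 64) noise tensor for the reparameterisation step produce scaled (1, 4, 64, 64) latents. A shape sketch (not part of the commit; same import assumption):

# Illustrative only: randomly initialised weights, shape check only.
import torch
from encoder import VAE_Encoder

encoder = VAE_Encoder()
image = torch.randn(1, 3, 512, 512)
noise = torch.randn(1, 4, 64, 64)   # matches the latent grid after three stride-2 convs
with torch.no_grad():
    latents = encoder(image, noise)
print(latents.shape)  # torch.Size([1, 4, 64, 64])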
 
src/model_converter.py DELETED
The diff for this file is too large to render. See raw diff
 
src/model_loader.py DELETED
@@ -1,40 +0,0 @@
- from clip import CLIP
- from encoder import VAE_Encoder
- from decoder import VAE_Decoder
- from diffusion import Diffusion
-
- import model_converter
- import torch
-
- def load_models(ckpt_path, device, use_se=False):
-     state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
-
-     encoder = VAE_Encoder().to(device)
-     encoder.load_state_dict(state_dict['encoder'], strict=True)
-
-     decoder = VAE_Decoder().to(device)
-     decoder.load_state_dict(state_dict['decoder'], strict=True)
-
-     # Initialize diffusion model with SE blocks disabled for loading pre-trained weights
-     diffusion = Diffusion(use_se=False).to(device)
-     diffusion.load_state_dict(state_dict['diffusion'], strict=True)
-
-     # If SE blocks are requested, reinitialize the model with them
-     if use_se:
-         diffusion = Diffusion(use_se=True).to(device)
-         # Copy the weights from the loaded model
-         with torch.no_grad():
-             for name, param in diffusion.named_parameters():
-                 if 'se' not in name:  # Skip SE block parameters
-                     if name in state_dict['diffusion']:
-                         param.copy_(state_dict['diffusion'][name])
-
-     clip = CLIP().to(device)
-     clip.load_state_dict(state_dict['clip'], strict=True)
-
-     return {
-         'clip': clip,
-         'encoder': encoder,
-         'decoder': decoder,
-         'diffusion': diffusion,
-     }
 
src/pipeline.py DELETED
@@ -1,123 +0,0 @@
- import torch
- import numpy as np
- from tqdm import tqdm
- from ddpm import DDPMSampler
- import logging
- from config import Config, default_config
-
- WIDTH = 512
- HEIGHT = 512
- LATENTS_WIDTH = WIDTH // 8
- LATENTS_HEIGHT = HEIGHT // 8
-
- logging.basicConfig(level=logging.INFO)
-
- def generate(
-     prompt,
-     uncond_prompt=None,
-     input_image=None,
-     config: Config = default_config,
- ):
-     with torch.no_grad():
-         validate_strength(config.diffusion.strength)
-         generator = initialize_generator(config.seed, config.device.device)
-         context = encode_prompt(prompt, uncond_prompt, config.diffusion.do_cfg, config.tokenizer, config.models["clip"], config.device.device)
-         latents = initialize_latents(input_image, config.diffusion.strength, generator, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps)
-         images = run_diffusion(latents, context, config.diffusion.do_cfg, config.diffusion.cfg_scale, config.models, config.device.device, config.diffusion.sampler_name, config.diffusion.n_inference_steps, generator)
-         return postprocess_images(images)
-
- def validate_strength(strength):
-     if not 0 < strength <= 1:
-         raise ValueError("Strength must be between 0 and 1")
-
- def initialize_generator(seed, device):
-     generator = torch.Generator(device=device)
-     if seed is None:
-         generator.seed()
-     else:
-         generator.manual_seed(seed)
-     return generator
-
- def encode_prompt(prompt, uncond_prompt, do_cfg, tokenizer, clip, device):
-     clip.to(device)
-     if do_cfg:
-         cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
-         cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
-         cond_context = clip(cond_tokens)
-         uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
-         uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
-         uncond_context = clip(uncond_tokens)
-         context = torch.cat([cond_context, uncond_context])
-     else:
-         tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
-         tokens = torch.tensor(tokens, dtype=torch.long, device=device)
-         context = clip(tokens)
-     return context
-
- def initialize_latents(input_image, strength, generator, models, device, sampler_name, n_inference_steps):
-     if input_image is None:
-         # Initialize with random noise
-         latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
-     else:
-         # Initialize with encoded input image
-         latents = encode_image(input_image, models, device)
-         # Add noise based on strength
-         noise = torch.randn_like(latents, generator=generator)
-         latents = (1 - strength) * latents + strength * noise
-     return latents
-
- def preprocess_image(input_image):
-     input_image_tensor = input_image.resize((WIDTH, HEIGHT))
-     input_image_tensor = np.array(input_image_tensor)
-     input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
-     input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
-     input_image_tensor = input_image_tensor.unsqueeze(0)
-     input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
-     return input_image_tensor
-
- def get_sampler(sampler_name, generator, n_inference_steps):
-     if sampler_name == "ddpm":
-         sampler = DDPMSampler(generator)
-         sampler.set_inference_timesteps(n_inference_steps)
-     else:
-         raise ValueError(f"Unknown sampler value {sampler_name}.")
-     return sampler
-
- def run_diffusion(latents, context, do_cfg, cfg_scale, models, device, sampler_name, n_inference_steps, generator):
-     diffusion = models["diffusion"]
-     diffusion.to(device)
-     sampler = get_sampler(sampler_name, generator, n_inference_steps)
-     timesteps = tqdm(sampler.timesteps)
-     for timestep in timesteps:
-         time_embedding = get_time_embedding(timestep).to(device)
-         model_input = latents.repeat(2, 1, 1, 1) if do_cfg else latents
-         model_output = diffusion(model_input, context, time_embedding)
-         if do_cfg:
-             output_cond, output_uncond = model_output.chunk(2)
-             model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
-         latents = sampler.step(timestep, latents, model_output)
-     decoder = models["decoder"]
-     decoder.to(device)
-     images = decoder(latents)
-     return images
-
- def postprocess_images(images):
-     images = rescale(images, (-1, 1), (0, 255), clamp=True)
-     images = images.permute(0, 2, 3, 1)
-     images = images.to("cpu", torch.uint8).numpy()
-     return images[0]
-
- def rescale(x, old_range, new_range, clamp=False):
-     old_min, old_max = old_range
-     new_min, new_max = new_range
-     x -= old_min
-     x *= (new_max - new_min) / (old_max - old_min)
-     x += new_min
-     if clamp:
-         x = x.clamp(new_min, new_max)
-     return x
-
- def get_time_embedding(timestep):
-     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
-     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
-     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
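The sinusoidal get_time_embedding above is what feeds TimeEmbedding(320) in diffusion.py: 160 frequencies, cosine and sine terms concatenated. A quick check of that contract (not part of the commit; assumes src/pipeline.py and its imports are available):

# Illustrative only.
from pipeline import get_time_embedding

emb = get_time_embedding(980)
print(emb.shape)  # torch.Size([1, 320]) -- 160 cos terms followed by 160 sin terms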