|
'''
This file contains the DDPM (Denoising Diffusion Probabilistic Model)
implementation, modularized for loading, prediction, and training.
'''
|
|
|
from torch import nn |
|
import math |
|
import torch |
|
from utils import forward_diffusion_sample, sample_timestep, sample_plot_image |
|
import torch.nn.functional as F |
|
from attn_utils import SelfAttention, Block_CBAM
|
|
|
class Block(nn.Module): |
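    """Convolutional block with time-step conditioning.

    Downsampling blocks halve the spatial resolution with a strided
    convolution; upsampling blocks (up=True) take the channel-concatenated
    skip connection as input (hence 2*in_ch) and double the resolution
    with a transposed convolution.
    """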
|
def __init__(self, in_ch, out_ch, time_emb_dim, up=False): |
|
super().__init__() |
|
self.time_mlp = nn.Linear(time_emb_dim, out_ch) |
|
if up: |
|
|
|
self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1) |
|
self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1) |
|
else: |
|
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) |
|
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
|
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) |
|
self.relu = nn.ReLU() |
|
self.batch_norm1 = nn.BatchNorm2d(out_ch) |
|
self.batch_norm2 = nn.BatchNorm2d(out_ch) |
|
|
|
    def forward(self, x, t):
|
        # First convolution.
        h = self.batch_norm1(self.relu(self.conv1(x)))

        # Project the time embedding to out_ch channels and broadcast it
        # over the spatial dimensions: (B, C) -> (B, C, 1, 1).
        time_emb = self.relu(self.time_mlp(t))
        time_emb = time_emb[(..., ) + (None, ) * 2]
        h = h + time_emb

        # Second convolution, then down/upsample.
        h = self.batch_norm2(self.relu(self.conv2(h)))
        return self.transform(h)
|
|
|
class PositionEmbeddings(nn.Module): |
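    """Sinusoidal position embeddings for the diffusion time step.

    Maps an integer time step t to a dim-dimensional vector of sin/cos
    features at geometrically spaced frequencies, following the
    Transformer positional encoding.
    """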
|
def __init__(self,dim): |
|
super().__init__() |
|
self.dim = dim |
|
|
|
def forward(self, time): |
|
device = time.device |
|
half_dim = self.dim // 2 |
|
embeddings = math.log(10000) / (half_dim - 1) |
|
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) |
|
embeddings = time[:, None] * embeddings[None, :] |
|
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) |
|
return embeddings |
|
|
|
|
|
|
|
class SimpleUnet(nn.Module): |
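    """A simplified U-Net that predicts the noise added at a given time step.

    Encoder: 64 -> 128 -> 256 -> 512 -> 1024 channels, halving the spatial
    resolution per block; the decoder mirrors it with skip connections.
    """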
|
def __init__(self): |
|
super().__init__() |
|
image_channels = 3 |
|
down_channels = (64, 128, 256, 512, 1024) |
|
up_channels = (1024, 512, 256, 128, 64) |
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
out_dim = 3 |
|
        time_emb_dim = 32
        self.num_timesteps = 300  # T used by sample(); matches the default training schedule
|
|
|
|
|
self.time_mlp = nn.Sequential( |
|
PositionEmbeddings(time_emb_dim), |
|
nn.Linear(time_emb_dim, time_emb_dim), |
|
nn.ReLU() |
|
) |
|
|
|
|
|
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1) |
|
self.down_blocks = nn.ModuleList([ |
|
Block(down_channels[i], down_channels[i+1], time_emb_dim) |
|
for i in range(len(down_channels)-1) |
|
]) |
|
self.up_blocks = nn.ModuleList([ |
|
Block(up_channels[i], up_channels[i+1], time_emb_dim, up=True) |
|
for i in range(len(up_channels)-1) |
|
]) |
|
|
|
|
|
self.output = nn.Conv2d(up_channels[-1], out_dim, 1) |
|
|
|
def forward(self, x, timestep): |
|
t = self.time_mlp(timestep) |
|
x = self.conv0(x) |
|
        # Encoder: store each block's output for the skip connections.
        residual_inputs = []
|
for down in self.down_blocks: |
|
x = down(x, t) |
|
residual_inputs.append(x) |
|
        # Decoder: concatenate the matching skip connection before each block.
        for up in self.up_blocks:
|
residual_x = residual_inputs.pop() |
|
x = torch.cat((x, residual_x), dim=1) |
|
x = up(x, t) |
|
return self.output(x) |
|
|
|
@torch.no_grad() |
|
def sample(self, noise): |
|
""" |
|
Generate an image by denoising a given noise tensor using the reverse diffusion process. |
|
|
|
Args: |
|
noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution). |
|
|
|
Returns: |
|
torch.Tensor: Denoised image. |
|
""" |
|
        img = noise
        T = self.num_timesteps

        # Reverse diffusion: iterate t = T-1, ..., 0, denoising one step at a time.
        for i in range(T - 1, -1, -1):
            t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long)
            img = sample_timestep(self, img, t)
            img = torch.clamp(img, -1.0, 1.0)
|
|
|
return img |
|
|
|
def get_loss(self, x_0, t): |
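        # DDPM simple loss: corrupt x_0 to step t with forward_diffusion_sample,
        # then train the network to predict the added noise (L1 here, in place
        # of the more common MSE).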
|
x_noisy, noise = forward_diffusion_sample(x_0, t, self.device) |
|
noise_pred = self(x_noisy, t) |
|
return F.l1_loss(noise, noise_pred) |
|
|
|
    def fit(self, dataloader, T=300, EPOCHS=50, verbose=True):
        # Named fit rather than train to avoid shadowing nn.Module.train().
        from torch.optim import Adam

        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=0.001)

        for epoch in range(EPOCHS):
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                # Sample one random diffusion step per image; sizing t by the
                # actual batch keeps the final, possibly smaller batch working.
                t = torch.randint(0, T, (batch[0].shape[0],), device=self.device).long()
                loss = self.get_loss(batch[0], t)
                loss.backward()
                optimizer.step()

                if verbose and epoch % 5 == 0 and step % 150 == 0:
                    print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()}")
                    sample_plot_image(self)
|
|
|
def test():
    # Minimal smoke test (a sketch, assuming 64x64 RGB inputs): a forward
    # pass should preserve the input's spatial shape.
    model = SimpleUnet()
    x = torch.randn(2, 3, 64, 64)
    t = torch.randint(0, model.num_timesteps, (2,), dtype=torch.long)
    out = model(x, t)
    assert out.shape == x.shape, f"unexpected output shape: {out.shape}"
|
class SimpleUnetWSelfAttn(nn.Module): |
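    """SimpleUnet variant with a self-attention layer at the bottleneck."""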
|
def __init__(self): |
|
super().__init__() |
|
image_channels = 3 |
|
down_channels = (64, 128, 256, 512, 1024) |
|
up_channels = (1024, 512, 256, 128, 64) |
|
|
|
out_dim = 3 |
|
time_emb_dim = 32 |
|
|
|
|
|
self.time_mlp = nn.Sequential( |
|
PositionEmbeddings(time_emb_dim), |
|
nn.Linear(time_emb_dim, time_emb_dim), |
|
nn.ReLU() |
|
) |
|
self.num_timesteps = 300 |
|
|
|
|
|
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1) |
|
self.down_blocks = nn.ModuleList([ |
|
Block(down_channels[i], down_channels[i+1], time_emb_dim) |
|
for i in range(len(down_channels)-1) |
|
]) |
|
self.up_blocks = nn.ModuleList([ |
|
Block(up_channels[i], up_channels[i+1], time_emb_dim, up=True) |
|
for i in range(len(up_channels)-1) |
|
]) |
|
|
|
        # Self-attention at the bottleneck, where the feature map is smallest.
        self.self_attention = SelfAttention(down_channels[-1])
|
|
|
|
|
|
|
self.output = nn.Conv2d(up_channels[-1], out_dim, 1) |
|
|
|
|
|
|
|
|
|
def forward(self, x, timestep): |
|
|
t = self.time_mlp(timestep) |
|
x = self.conv0(x) |
|
residual_inputs = [] |
|
for down in self.down_blocks: |
|
x = down(x, t) |
|
residual_inputs.append(x) |
|
|
|
x = self.self_attention(x) |
|
|
|
for up in self.up_blocks: |
|
residual_x = residual_inputs.pop() |
|
x = torch.cat((x, residual_x), dim=1) |
|
x = up(x, t) |
|
return self.output(x) |
|
|
|
@torch.no_grad() |
|
def sample(self, noise): |
|
""" |
|
Generate an image by denoising a given noise tensor using the reverse diffusion process. |
|
|
|
Args: |
|
noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution). |
|
|
|
Returns: |
|
torch.Tensor: Denoised image. |
|
""" |
|
img = noise |
|
T = self.num_timesteps |
|
        # Reverse diffusion: iterate t = T-1, ..., 0, denoising one step at a time.
        for i in range(T - 1, -1, -1):
|
t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long) |
|
img = sample_timestep(self, img, t) |
|
img = torch.clamp(img, -1.0, 1.0) |
|
|
|
return img |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleUnetWCBAM(nn.Module): |
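    """SimpleUnet variant using CBAM-augmented blocks plus bottleneck self-attention."""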
|
def __init__(self): |
|
super().__init__() |
|
image_channels = 3 |
|
down_channels = (64, 128, 256, 512, 1024) |
|
up_channels = (1024, 512, 256, 128, 64) |
|
|
|
out_dim = 3 |
|
time_emb_dim = 32 |
|
|
|
|
|
self.time_mlp = nn.Sequential( |
|
PositionEmbeddings(time_emb_dim), |
|
nn.Linear(time_emb_dim, time_emb_dim), |
|
nn.ReLU() |
|
) |
|
self.num_timesteps = 300 |
|
|
|
|
|
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1) |
|
self.down_blocks = nn.ModuleList([ |
|
Block_CBAM(down_channels[i], down_channels[i+1], time_emb_dim) |
|
for i in range(len(down_channels)-1) |
|
]) |
|
self.up_blocks = nn.ModuleList([ |
|
Block_CBAM(up_channels[i], up_channels[i+1], time_emb_dim, up=True) |
|
for i in range(len(up_channels)-1) |
|
]) |
|
|
|
        # Self-attention at the bottleneck, where the feature map is smallest.
        self.self_attention = SelfAttention(down_channels[-1])
|
|
|
|
|
|
|
self.output = nn.Conv2d(up_channels[-1], out_dim, 1) |
|
|
|
|
|
|
|
|
|
def forward(self, x, timestep): |
|
|
t = self.time_mlp(timestep) |
|
x = self.conv0(x) |
|
residual_inputs = [] |
|
for down in self.down_blocks: |
|
x = down(x, t) |
|
residual_inputs.append(x) |
|
|
|
x = self.self_attention(x) |
|
|
|
for up in self.up_blocks: |
|
residual_x = residual_inputs.pop() |
|
x = torch.cat((x, residual_x), dim=1) |
|
x = up(x, t) |
|
return self.output(x) |
|
|
|
@torch.no_grad() |
|
def sample(self, noise): |
|
""" |
|
Generate an image by denoising a given noise tensor using the reverse diffusion process. |
|
|
|
Args: |
|
noise (torch.Tensor): Initial noise tensor (e.g., sampled from a Gaussian distribution). |
|
|
|
Returns: |
|
torch.Tensor: Denoised image. |
|
""" |
|
img = noise |
|
T = self.num_timesteps |
|
        # Reverse diffusion: iterate t = T-1, ..., 0, denoising one step at a time.
        for i in range(T - 1, -1, -1):
|
t = torch.full((noise.size(0),), i, device=noise.device, dtype=torch.long) |
|
img = sample_timestep(self, img, t) |
|
img = torch.clamp(img, -1.0, 1.0) |
|
|
|
return img |
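

# Minimal usage sketch. The random-tensor dataset below is a stand-in
# assumption for a real image DataLoader (64x64 RGB scaled to [-1, 1]);
# fit and sample are the methods defined on SimpleUnet above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    fake_images = torch.randn(256, 3, 64, 64)  # placeholder data, not a real dataset
    loader = DataLoader(TensorDataset(fake_images), batch_size=64, shuffle=True)

    model = SimpleUnet()
    model.fit(loader, T=300, EPOCHS=1, verbose=False)

    noise = torch.randn(4, 3, 64, 64, device=model.device)
    samples = model.sample(noise)
    print(samples.shape)  # expected: torch.Size([4, 3, 64, 64])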
|
|