# Flexi-Propagator / model_adv_dif.py
# Khalid Rafiq
# Add all required modules and requirements.txt
# ab72d17
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import torch
import torch.nn as nn
import math
import torch
import os
import torch
import torch.nn as nn
import numpy as np
import pickle
from dataclasses import dataclass, asdict
import json
from torch.utils.data import DataLoader
# Group normalization applied after each Conv2D stage.
class Norm(nn.Module):
    """Wrap ``nn.GroupNorm`` so it can be dropped into a conv stack by channel count.

    Args:
        num_channels: Number of channels C of the incoming ``(N, C, ...)`` tensor.
        num_groups: Number of channel groups (default 4); ``nn.GroupNorm``
            requires it to divide ``num_channels``.
    """

    def __init__(self, num_channels, num_groups=4):
        super().__init__()
        group_norm = nn.GroupNorm(num_groups, num_channels)
        self.norm = group_norm

    def forward(self, x):
        """Return ``x`` group-normalized (same shape as the input)."""
        normalized = self.norm(x)
        return normalized
# Encoder using Conv2D
class Encoder(nn.Module):
    """Convolutional VAE encoder: image -> (mean, log_var) of a diagonal Gaussian.

    Five convolutions with kernel 2, stride 2 and no padding each halve the
    spatial size. A 128 x 128 input therefore ends as a (512, 4, 4) feature
    map, which is what the ``512 * 4 * 4`` input size of the two linear heads
    expects. Other spatial sizes give a flattened size that does not match
    the heads, and ``fc_mean`` raises.

    Args:
        latent_dim: Size of the latent vector produced by each head.
        in_channels: Number of input image channels. The default of 1 is the
            value the original implementation hard-coded.
    """

    def __init__(self, latent_dim=3, in_channels=1):
        super().__init__()
        self.conv_layers = nn.Sequential(
            # Input: (batch_size, in_channels, 128, 128)
            nn.Conv2d(in_channels, 32, kernel_size=2, stride=2, padding=0),  # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),  # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),
            nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0),  # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),
            nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=0),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),
            nn.Conv2d(256, 512, kernel_size=2, stride=2, padding=0),  # (batch_size, 512, 4, 4)
            nn.GELU(),
            Norm(512),
        )
        self.flatten = nn.Flatten()  # (batch_size, 512 * 4 * 4)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        """Encode ``x`` into the parameters of the approximate posterior.

        Args:
            x: Image batch, expected ``(batch_size, in_channels, 128, 128)``.

        Returns:
            ``(mean, log_var)``, each of shape ``(batch_size, latent_dim)``.
        """
        x = self.conv_layers(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var
class Decoder(nn.Module):
    """Decoder mirroring ``Encoder``: latent vector -> (batch_size, out_channels, 128, 128).

    A linear layer lifts z to a (512, 4, 4) feature map. Five stages of
    2x bilinear upsampling, each followed by a 1x1 convolution, then grow
    the map to 128 x 128. Along the way the channels shrink
    512 -> 256 -> 128 -> 64 -> 32 -> out_channels. The final ReLU makes
    every output value non-negative.

    Args:
        latent_dim: Size of the latent vector.
        out_channels: Number of output image channels. The default of 1 is
            the value the original implementation hard-coded.
    """

    def __init__(self, latent_dim=3, out_channels=1):
        super().__init__()
        # Lift the latent vector to 512 * 4 * 4 features; forward() reshapes
        # them to (batch_size, 512, 4, 4).
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.deconv_layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 512, 8, 8)
            nn.Conv2d(512, 256, kernel_size=1),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 256, 16, 16)
            nn.Conv2d(256, 128, kernel_size=1),  # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 128, 32, 32)
            nn.Conv2d(128, 64, kernel_size=1),  # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 64, 64, 64)
            nn.Conv2d(64, 32, kernel_size=1),  # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # (batch_size, 32, 128, 128)
            nn.Conv2d(32, out_channels, kernel_size=1),  # (batch_size, out_channels, 128, 128)
            nn.ReLU(),
        )

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent tensor whose last dimension is ``latent_dim``. Any
                leading dimensions are folded into the batch axis by ``view``.

        Returns:
            Tensor of shape ``(batch_size, out_channels, 128, 128)``.
        """
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)  # (batch_size, 512, 4, 4)
        x = self.deconv_layers(x)
        return x
class Propagator_concat(nn.Module):
    """Latent-space propagator: maps (z(t), tau, alpha) to z(t + tau).

    The latent vector is concatenated with the two per-sample scalars
    ``tau`` and ``alpha`` and passed through an MLP with GELU activations.
    ``alpha`` is presumably a physical/conditioning parameter of the
    simulated system; confirm with the caller.
    """

    def __init__(self, latent_dim, feats=(16, 32, 64, 32, 16)):
        """Build the MLP.

        Args:
            latent_dim: Size of the latent vector z (also the output size).
            feats: Hidden-layer widths. Each entry adds one ``Linear`` + ``GELU``
                pair. The default reproduces the original 5-hidden-layer
                network, with the same module indices and initialization order.
                A tuple is used so the default is not a shared mutable object.
        """
        super().__init__()
        # Input width is latent_dim + 2: one column for tau, one for alpha.
        widths = [latent_dim + 2, *feats]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.GELU())
        layers.append(nn.Linear(widths[-1], latent_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, z, tau, alpha):
        """Advance the latent state by ``tau``.

        Args:
            z: Latent tensor ``(batch, latent_dim)``, or ``(batch, 1, latent_dim)``.
                The singleton axis is dropped.
            tau: Time step per sample, ``(batch, 1)`` or ``(batch,)``.
            alpha: Extra scalar per sample, ``(batch, 1)`` or ``(batch,)``.

        Returns:
            ``(z_tau, z_)``: the predicted latent ``(batch, latent_dim)`` and the
            concatenated network input ``(batch, latent_dim + 2)``.
        """
        # Only squeeze a 3-D singleton axis. Squeezing a 2-D (batch, 1)
        # latent (latent_dim == 1) would drop its feature axis.
        if z.dim() == 3 and z.size(1) == 1:
            z = z.squeeze(1)
        # Accept per-sample scalars given as flat (batch,) vectors.
        if tau.dim() == 1:
            tau = tau.unsqueeze(1)
        if alpha.dim() == 1:
            alpha = alpha.unsqueeze(1)
        z_ = torch.cat((z, tau, alpha), dim=1)
        z_tau = self._net(z_)
        return z_tau, z_
class Model(nn.Module):
    """VAE whose latent state is advanced in time by a propagator network.

    Args:
        encoder: Module mapping ``x`` to ``(mean, log_var)``.
        decoder: Module mapping a latent sample back to data space. It is used
            for both ``x(t)`` and ``x(t + tau)``.
        propagator: Module mapping ``(z, tau, alpha)`` to ``(z_tau, z_)``.
    """

    def __init__(self, encoder, decoder, propagator):
        super().__init__()
        # Registration order is kept: encoder, decoder, propagator.
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator

    def reparameterization(self, mean, var):
        """Draw ``mean + var * eps`` with ``eps ~ N(0, I)``.

        Despite its name, ``forward`` passes the standard deviation
        ``exp(0.5 * log_var)`` as ``var``, not the variance.
        """
        noise = torch.randn_like(var)
        return mean + noise * var

    def forward(self, x, tau, alpha):
        """Encode ``x``, sample z, propagate it by ``tau``, and decode both latents.

        Returns:
            ``(x_hat, x_hat_tau, mean, log_var, z_tau, z_)``. ``x_hat`` decodes
            the sampled z and ``x_hat_tau`` decodes the propagated ``z_tau``.
            ``z_`` is whatever the propagator returns as its second output.
        """
        mean, log_var = self.encoder(x)
        std = log_var.mul(0.5).exp()
        z = self.reparameterization(mean, std)
        z_tau, z_ = self.propagator(z, tau, alpha)
        x_hat = self.decoder(z)
        x_hat_tau = self.decoder(z_tau)
        return x_hat, x_hat_tau, mean, log_var, z_tau, z_
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """Compute the three loss terms of the time-conditioned VAE.

    :param x: Original input x(t)
    :param x_tau: Ground-truth future input x(t+tau)
    :param x_hat: Reconstruction of x(t)
    :param x_hat_tau: Prediction of x(t+tau)
    :param mean: Mean of the latent distribution
    :param log_var: Log variance of the latent distribution
    :return: ``(reconstruction_loss1, reconstruction_loss2, KLD)`` as scalar tensors:
        the mean-squared reconstruction error, the mean-squared prediction error,
        and the KL divergence from N(0, I). The KL term is summed over dim 1 and
        averaged over the remaining axis (presumably the batch).
    """
    mse = nn.functional.mse_loss  # MSE is symmetric in its two arguments
    recon_now = mse(x_hat, x)
    recon_future = mse(x_hat_tau, x_tau)
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kld = (-0.5 * kl_terms.sum(dim=1)).mean()
    return recon_now, recon_future, kld