#!/usr/bin/env python
# coding: utf-8

# In[1]:

import json
import math
import os
import pickle
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class Norm(nn.Module):
    """Normalization layer for Conv2D feature maps (thin wrapper over GroupNorm).

    :param num_channels: Number of channels of the incoming feature map.
    :param num_groups: Number of groups passed to ``nn.GroupNorm``.
    """

    def __init__(self, num_channels, num_groups=4):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        """Apply group normalization to ``x`` and return the result."""
        return self.norm(x)


class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of the latent.

    Five kernel-2 / stride-2 convolutions each halve the spatial size, and the
    linear heads expect ``512 * 4 * 4`` flattened features, so the input must
    be ``(batch_size, 1, 128, 128)`` (128 / 2**5 == 4).

    :param latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=3):
        super().__init__()

        self.conv_layers = nn.Sequential(
            # Input: (batch_size, 1, 128, 128)
            nn.Conv2d(1, 32, kernel_size=2, stride=2, padding=0),     # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),    # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),
            nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0),   # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),
            nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=0),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),
            nn.Conv2d(256, 512, kernel_size=2, stride=2, padding=0),  # (batch_size, 512, 4, 4)
            nn.GELU(),
            Norm(512),
        )

        self.flatten = nn.Flatten()
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mean, log_var)``, each ``(batch_size, latent_dim)``."""
        x = self.conv_layers(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var


class Decoder(nn.Module):
    """Decoder mapping a latent vector back to a single-channel image.

    The latent is projected to a ``(512, 4, 4)`` feature map, then upsampled
    by a factor of two five times, giving ``(batch_size, 1, 128, 128)``.
    The trailing ReLU makes the output non-negative.

    :param latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=3):
        super().__init__()

        # Fully connected layer to transform the latent vector into the
        # flattened (512, 4, 4) feature map consumed by ``deconv_layers``.
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        self.deconv_layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=1),
            nn.GELU(),
            Norm(256),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=1),
            nn.GELU(),
            Norm(128),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=1),
            nn.GELU(),
            Norm(64),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=1),
            nn.GELU(),
            Norm(32),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, z):
        """Decode latent ``z`` of shape ``(batch_size, latent_dim)`` into an image."""
        # Transform the latent vector to match the shape of the feature maps
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)  # Reshape to (batch_size, 512, 4, 4)
        x = self.deconv_layers(x)
        return x


class Propagator_concat(nn.Module):
    """
    Takes in (z(t), tau, alpha) and outputs z(t+tau)
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, latent_dim, feats=(16, 32, 64, 32, 16)):
        """
        Initialize the propagator network.

        Input : (z(t), tau, alpha)
        Output: z(t+tau)

        :param latent_dim: Size of the latent vector z.
        :param feats: Hidden layer widths; the first five entries are used.
        """
        super().__init__()

        self._net = nn.Sequential(
            # +2: one input column for tau and one for alpha;
            # more params will increase this
            nn.Linear(latent_dim + 2, feats[0]),
            nn.GELU(),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            nn.Linear(feats[4], latent_dim),
        )

    def forward(self, z, tau, alpha):
        """
        Forward pass of the propagator.

        Concatenates latent vector z with tau and alpha along dim 1 and
        processes the result through the network.

        :param z: Latent of shape (batch, latent_dim) or (batch, 1, latent_dim)
        :param tau: Tensor concatenated after z along dim 1
        :param alpha: Tensor concatenated after tau along dim 1
        :return: (z_tau, z_) -- the propagated latent and the concatenated
                 network input
        """
        # Drop a singleton axis 1 only when z carries an extra dimension.
        # Squeezing a 2-D (batch, 1) latent would collapse the latent axis
        # itself and make the concatenation below fail when latent_dim == 1.
        zproj = z.squeeze(1) if z.dim() > 2 else z

        # Concatenate z, tau and alpha along the feature dimension
        z_ = torch.cat((zproj, tau, alpha), dim=1)
        z_tau = self._net(z_)
        return z_tau, z_


class Model(nn.Module):
    """VAE with a latent-space propagator.

    :param encoder: Module returning ``(mean, log_var)`` for an input x(t).
    :param decoder: Module mapping a latent vector to a reconstruction.
    :param propagator: Module marching z(t) to z(t+tau) given tau and alpha.
    """

    def __init__(self, encoder, decoder, propagator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)
        self.propagator = propagator  # used to time march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        """Sample ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        NOTE: despite its name, ``var`` is used as a multiplicative scale;
        ``forward`` passes the standard deviation ``exp(0.5 * log_var)``.
        The parameter name is kept so existing keyword callers still work.
        """
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, tau, alpha):
        """Encode x(t), propagate the latent by tau, and decode both latents.

        :return: ``(x_hat, x_hat_tau, mean, log_var, z_tau, z_)`` where
                 ``z_`` is the concatenated propagator input.
        """
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Small fully connected network to get z(t+tau) from z(t)
        z_tau, z_ = self.propagator(z, tau, alpha)

        # Reconstruction
        x_hat = self.decoder(z)  # Reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # Prediction of x(t+tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau, z_


def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """
    Compute the VAE loss components.

    :param x: Original input
    :param x_tau: Future input (ground truth)
    :param x_hat: Reconstructed x(t)
    :param x_hat_tau: Predicted x(t+tau)
    :param mean: Mean of the latent distribution
    :param log_var: Log variance of the latent distribution
    :return: reconstruction_loss1, reconstruction_loss2, KLD
    """
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)  # Reconstruction loss for x(t)
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)  # Prediction loss for x(t+tau)

    # Kullback-Leibler divergence: summed over dim 1, averaged over the rest
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1))

    return reconstruction_loss1, reconstruction_loss2, KLD