#!/usr/bin/env python
# coding: utf-8

# In[1]:


import math

import numpy as np
import torch
import torch.nn as nn


# In[2]:


def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                         -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe


# In[3]:


class Norm(nn.Module):
    """GroupNorm over the channel dimension of a (batch, length, channels) tensor."""

    def __init__(self, num_channels, num_groups=4):
        super(Norm, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        # GroupNorm expects (batch, channels, length), so permute in and out.
        return self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)


# In[4]:


class Norm_new(nn.Module):
    """GroupNorm that accepts either (batch, channels) or (batch, length, channels) inputs."""

    def __init__(self, num_channels, num_groups=4):
        super(Norm_new, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        if x.dim() == 2:
            # Reshape to (batch_size, num_channels, 1)
            x = x.unsqueeze(-1)
            x = self.norm(x)
            # Reshape back to (batch_size, num_channels)
            x = x.squeeze(-1)
        else:
            x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x


class Encoder(nn.Module):
    """VAE encoder: maps x(t) to the mean and log-variance of a Gaussian over z(t)."""

    def __init__(self, input_dim, latent_dim=2, feats=[512, 256, 128, 64, 32]):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self._net = nn.Sequential(
            nn.Linear(input_dim, feats[0]), nn.GELU(), Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]), nn.GELU(), Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]), nn.GELU(), Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]), nn.GELU(), Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]), nn.GELU(), Norm_new(feats[4]),
            nn.Linear(feats[4], 2 * latent_dim),
        )

    def forward(self, x):
        Z = self._net(x)
        mean, log_var = torch.split(Z, self.latent_dim, dim=-1)
        return mean, log_var


# In[5]:


class Decoder(nn.Module):
    """VAE decoder: maps a latent z back to the full state x."""

    def __init__(self, latent_dim, output_dim, feats=[32, 64, 128, 256, 512]):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self._net = nn.Sequential(
            nn.Linear(latent_dim, feats[0]), nn.GELU(), Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]), nn.GELU(), Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]), nn.GELU(), Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]), nn.GELU(), Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]), nn.GELU(), Norm_new(feats[4]),
            nn.Linear(feats[4], output_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        y = self._net(x)
        return y


# In[6]:


class Propagator(nn.Module):
    # Takes (z(t), tau) and outputs z(t+tau).
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64):
        """
        Input : (z(t), tau)
        Output: z(t+tau)
        """
        super(Propagator, self).__init__()
        self.max_tau = max_tau
        # Sinusoidal encodings for the integer offsets tau; shape: (max_tau, encoding_dim)
        self.register_buffer('encodings', positionalencoding1d(encoding_dim, max_tau))
        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau):
        zproj = self.projector(z)           # project z from latent_dim to encoding_dim
        enc = self.encodings[tau.long()]    # look up the encoding of the offset tau
        z = zproj + enc                     # combine by elementwise addition
        z_tau = self._net(z)
        return z_tau
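

# Illustrative shape check (not part of the original notebook): a minimal sketch
# of how the pieces above fit together, assuming an input field flattened to
# 1024 values, a 2-D latent space, and a batch of 8 snapshots. All dimensions
# and the tau values here are assumptions chosen only to make the shapes concrete.
def _shape_check_propagator():
    input_dim, latent_dim, batch = 1024, 2, 8
    enc = Encoder(input_dim, latent_dim=latent_dim)
    dec = Decoder(latent_dim, input_dim)
    prop = Propagator(latent_dim)

    x = torch.randn(batch, 1, input_dim)        # (batch, seq=1, input_dim)
    mean, log_var = enc(x)                      # each (batch, 1, latent_dim)
    z = mean                                    # use the mean as a point estimate here
    tau = torch.randint(0, 10000, (batch, 1))   # integer time offsets, below max_tau
    z_tau = prop(z, tau)                        # (batch, 1, latent_dim)
    x_hat_tau = dec(z_tau)                      # (batch, 1, input_dim)
    print(mean.shape, z_tau.shape, x_hat_tau.shape)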


# Same propagator, but with an additional positional-style embedding for the
# Reynolds number Re.
class Propagator_encoding(nn.Module):
    # Takes (z(t), tau, Re) and outputs z(t+tau).
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64, max_re=5000):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_encoding, self).__init__()
        self.max_tau = max_tau
        self.max_re = max_re
        # Sinusoidal encodings for tau and Re; shapes: (max_tau, encoding_dim) and (max_re, encoding_dim)
        self.register_buffer('tau_encodings', positionalencoding1d(encoding_dim, max_tau))
        self.register_buffer('re_encodings', positionalencoding1d(encoding_dim, max_re))
        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        zproj = self.projector(z)                  # project z from latent_dim to encoding_dim
        tau_enc = self.tau_encodings[tau.long()]   # encoding of the time offset
        re_enc = self.re_encodings[re.long()]      # encoding of the Reynolds number
        z = zproj + tau_enc + re_enc               # combine by elementwise addition
        #print("shape after enc addition: ", z.shape)
        z_tau = self._net(z)
        #print("shape z_tau: ", z_tau.shape)
        return z_tau


class Propagator_concat(nn.Module):
    # Takes (z(t), tau, Re) and outputs z(t+tau); tau and Re are concatenated to z.
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_concat, self).__init__()
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 2, feats[0]),
            nn.ReLU(),
            #Norm(feats[1]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            #Norm(feats[2]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        zproj = z.squeeze(1)                     # (batch, 1, latent_dim) -> (batch, latent_dim)
        z_ = torch.cat((zproj, tau, re), dim=1)  # append tau and Re as extra features
        z_tau = self._net(z_)
        z_tau = z_tau[:, None, :]                # restore the sequence dimension
        return z_tau


class Propagator_concat_one_step(nn.Module):
    # Takes (z(t), Re) and outputs z(t + dt), i.e. a single time step.
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Input : (z(t), re)
        Output: z(t+1*dt)
        """
        super(Propagator_concat_one_step, self).__init__()
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 1, feats[0]),
            nn.ReLU(),
            #Norm(feats[1]),
            nn.Linear(feats[0], feats[1]),
            nn.Tanh(),
            #Norm(feats[2]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, re):
        #zproj = z.squeeze(1)
        zproj = z
        z_ = torch.cat((zproj, re), dim=1)   # append Re as an extra feature
        z_tau = self._net(z_)
        #z_tau = z_tau[:, None, :]
        return z_tau


class Model(nn.Module):
    def __init__(self, encoder, decoder, propagator):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder          # decoder for x(t)
        self.propagator = propagator    # used to time-march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, tau, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Small FCNN propagator gets z(t+tau) from z(t)
        z_tau = self.propagator(z, tau, re)
        # Reconstructions
        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction of x(t+tau)
        return x_hat, x_hat_tau, mean, log_var, z_tau
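

# Illustrative wiring sketch (an assumption about intended usage, not the
# original training code): combines the encoder, decoder, and the concatenating
# propagator into the full Model and runs one forward pass. The field size,
# latent width, and the raw tau / Reynolds-number values are made up.
def _shape_check_full_model():
    input_dim, latent_dim, batch = 1024, 2, 8
    model = Model(
        encoder=Encoder(input_dim, latent_dim=latent_dim),
        decoder=Decoder(latent_dim, input_dim),
        propagator=Propagator_concat(latent_dim),
    )
    x = torch.randn(batch, 1, input_dim)   # snapshot x(t)
    tau = torch.full((batch, 1), 5.0)      # time offset as a float column
    re = torch.full((batch, 1), 100.0)     # Reynolds number as a float column
    x_hat, x_hat_tau, mean, log_var, z_tau = model(x, tau, re)
    print(x_hat.shape, x_hat_tau.shape, z_tau.shape)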


class Model_One_Step(nn.Module):
    # Takes only x and Re (no tau), since the propagator always advances one step (tau = 1).
    def __init__(self, encoder, decoder, propagator):
        super(Model_One_Step, self).__init__()
        self.encoder = encoder
        self.decoder = decoder          # decoder for x(t)
        self.propagator = propagator    # used to time-march z(t) to z(t+dt)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Small FCNN propagator gets z(t+1*dt) from z(t) -- use Propagator_concat_one_step here.
        z_tau = self.propagator(z, re)
        # Reconstructions
        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction of x(t+dt)
        return x_hat, x_hat_tau, mean, log_var, z_tau


class Model_reproduce(nn.Module):
    # Plain VAE: encode x(t) and reconstruct it, with no propagation in latent space.
    def __init__(self, encoder, decoder):
        super(Model_reproduce, self).__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)
        return x_hat, mean, log_var


# Loss for the reconstruction-only model: MSE plus the KL divergence of the
# approximate posterior from a standard normal. The sum over dim=2 assumes a
# (batch, seq, latent_dim) layout for mean and log_var.
def loss_function_reproduce(x, x_hat, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, KLD


# Loss for the propagating models: reconstruction error at t, prediction error
# at t+tau, and the KL divergence term.
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, reconstruction_loss2, KLD


def count_parameters(model):
    """Number of trainable parameters in a model."""
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params
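

# Illustrative training step (a sketch, not the original training loop): it
# exercises Model_reproduce, loss_function_reproduce, and count_parameters on
# random data. The data shape, learning rate, and the equal weighting of the
# reconstruction and KLD terms are assumptions made only for this example.
def _example_training_step():
    input_dim, latent_dim, batch = 1024, 2, 8
    model = Model_reproduce(Encoder(input_dim, latent_dim=latent_dim),
                            Decoder(latent_dim, input_dim))
    print("trainable parameters:", count_parameters(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.randn(batch, 1, input_dim)   # 3-D layout so the KLD's sum over dim=2 is valid

    x_hat, mean, log_var = model(x)
    recon, kld = loss_function_reproduce(x, x_hat, mean, log_var)
    loss = recon + kld                     # assumed equal weighting of the two terms

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("recon:", recon.item(), "KLD:", kld.item())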