#!/usr/bin/env python
# coding: utf-8
# In[1]:

# Flexi-Propagator -- model_v2.py
import math

import numpy as np
import torch
import torch.nn as nn


# In[2]:
def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                          -(math.log(10000.0) / d_model)))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe
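
# Illustrative check of the encoding table (not part of the original module):
# row t of the returned matrix is the d_model-dimensional encoding of offset t.
# The dimensions below match the defaults used by the propagators further down.
def _demo_positional_encoding():
    pe = positionalencoding1d(d_model=64, length=10000)
    assert pe.shape == (10000, 64)
    return pe[5]  # encoding looked up for an integer offset of 5
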
# In[3]:
class Norm(nn.Module):
    def __init__(self, num_channels, num_groups=4):
        super(Norm, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        # Expects x of shape (batch, length, channels); GroupNorm normalizes the
        # channel dimension, hence the permutes.
        return self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
# In[4]:
class Norm_new(nn.Module):
    # Variant of Norm that also handles 2-D inputs of shape (batch, channels).
    def __init__(self, num_channels, num_groups=4):
        super(Norm_new, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        if x.dim() == 2:
            # Reshape to (batch_size, num_channels, 1)
            x = x.unsqueeze(-1)
            x = self.norm(x)
            # Reshape back to (batch_size, num_channels)
            x = x.squeeze(-1)
        else:
            x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x
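
# Minimal sketch of the two input layouts Norm_new accepts (batch size 8, length 5
# and 32 channels are illustrative values, not taken from the original training code).
def _demo_norm_new():
    norm = Norm_new(num_channels=32)
    flat = norm(torch.randn(8, 32))        # 2-D path: (batch, channels)
    seq = norm(torch.randn(8, 5, 32))      # 3-D path: (batch, length, channels)
    return flat.shape, seq.shape           # (8, 32) and (8, 5, 32)
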
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim=2, feats=[512, 256, 128, 64, 32]):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self._net = nn.Sequential(
            nn.Linear(input_dim, feats[0]),
            nn.GELU(),
            Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            Norm_new(feats[4]),
            nn.Linear(feats[4], 2 * latent_dim)
        )

    def forward(self, x):
        Z = self._net(x)
        mean, log_var = torch.split(Z, self.latent_dim, dim=-1)
        return mean, log_var
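
# Usage sketch for the Encoder. The input width (100) and the singleton "time"
# axis are assumptions for illustration; the rest of the file treats snapshots
# as tensors of shape (batch, 1, input_dim).
def _demo_encoder():
    enc = Encoder(input_dim=100, latent_dim=2)
    x = torch.randn(8, 1, 100)
    mean, log_var = enc(x)                 # each: (8, 1, 2)
    return mean.shape, log_var.shape
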
# In[5]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, feats=[32, 64, 128, 256, 512]):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self._net = nn.Sequential(
            nn.Linear(latent_dim, feats[0]),
            nn.GELU(),
            Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            Norm_new(feats[4]),
            nn.Linear(feats[4], output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        y = self._net(x)
        return y
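
# Usage sketch for the Decoder, mirroring the Encoder demo above (output_dim=100
# is an assumed value; the final Tanh squashes the output to [-1, 1]).
def _demo_decoder():
    dec = Decoder(latent_dim=2, output_dim=100)
    z = torch.randn(8, 1, 2)
    x_hat = dec(z)                         # (8, 1, 100)
    return x_hat.shape
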
# In[6]:
class Propagator(nn.Module):
    # Takes in (z(t), tau) and outputs z(t+tau).
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64):
        """
        Input : (z(t), tau)
        Output: z(t+tau)
        """
        super(Propagator, self).__init__()
        self.max_tau = max_tau
        # Fixed sinusoidal table; row tau is the encoding of offset tau. Shape: (max_tau, encoding_dim)
        self.register_buffer('encodings', positionalencoding1d(encoding_dim, max_tau))
        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau):
        # Project z into the encoding space and add the sinusoidal encoding of tau.
        zproj = self.projector(z)
        enc = self.encodings[tau.long()]
        z = zproj + enc
        z_tau = self._net(z)
        return z_tau
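
# Sketch of how the additive-encoding propagator might be called. Because the
# projector uses Norm (which permutes dims 1 and 2), z must be 3-D, e.g.
# (batch, 1, latent_dim); tau is an integer offset indexing the encoding table.
# Batch size and dimensions below are illustrative assumptions.
def _demo_propagator():
    prop = Propagator(latent_dim=2)
    z = torch.randn(8, 1, 2)
    tau = torch.randint(0, 10000, (8, 1))  # one integer offset per sample
    z_tau = prop(z, tau)                   # (8, 1, 2)
    return z_tau.shape
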
# Same as Propagator, but with an additional sinusoidal embedding for the Reynolds number Re.
class Propagator_encoding(nn.Module):
    # Takes in (z(t), tau, Re) and outputs z(t+tau).
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64, max_re=5000):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_encoding, self).__init__()
        self.max_tau = max_tau
        self.max_re = max_re
        # Fixed sinusoidal tables for the time offset and the Reynolds number.
        self.register_buffer('tau_encodings', positionalencoding1d(encoding_dim, max_tau))  # (max_tau, encoding_dim)
        self.register_buffer('re_encodings', positionalencoding1d(encoding_dim, max_re))    # (max_re, encoding_dim)
        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        # Project z and add the encodings of both tau and Re.
        zproj = self.projector(z)
        tau_enc = self.tau_encodings[tau.long()]
        re_enc = self.re_encodings[re.long()]
        z = zproj + tau_enc + re_enc
        z_tau = self._net(z)
        return z_tau
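
# Same calling pattern as the Propagator demo, with a second integer index for Re.
# All shapes and ranges below are illustrative assumptions.
def _demo_propagator_encoding():
    prop = Propagator_encoding(latent_dim=2)
    z = torch.randn(8, 1, 2)
    tau = torch.randint(0, 10000, (8, 1))
    re = torch.randint(0, 5000, (8, 1))
    z_tau = prop(z, tau, re)               # (8, 1, 2)
    return z_tau.shape
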
class Propagator_concat(nn.Module):
    # Takes in (z(t), tau, Re) and outputs z(t+tau); tau and Re enter as raw features.
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_concat, self).__init__()
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 2, feats[0]),
            nn.ReLU(),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        zproj = z.squeeze(1)                     # (B, 1, latent_dim) -> (B, latent_dim)
        z_ = torch.cat((zproj, tau, re), dim=1)  # append tau and Re as extra features
        z_tau = self._net(z_)
        z_tau = z_tau[:, None, :]                # restore the singleton time dimension
        return z_tau
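
# Sketch of the concatenation variant: tau and Re are concatenated to z as scalar
# features, so they are float tensors of shape (batch, 1). Values are illustrative.
def _demo_propagator_concat():
    prop = Propagator_concat(latent_dim=2)
    z = torch.randn(8, 1, 2)
    tau = torch.rand(8, 1)
    re = torch.rand(8, 1)
    z_tau = prop(z, tau, re)               # (8, 1, 2)
    return z_tau.shape
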
class Propagator_concat_one_step(nn.Module):
    # Takes in (z(t), Re) and outputs z(t+dt), i.e. a single-step advance.
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Input : (z(t), re)
        Output: z(t+1*dt)
        """
        super(Propagator_concat_one_step, self).__init__()
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 1, feats[0]),
            nn.ReLU(),
            nn.Linear(feats[0], feats[1]),
            nn.Tanh(),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, re):
        # z is expected as (B, latent_dim); re as (B, 1).
        z_ = torch.cat((z, re), dim=1)
        z_tau = self._net(z_)
        return z_tau
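
# One-step variant sketch: only Re is appended, and z stays 2-D (batch, latent_dim).
# Shapes are illustrative assumptions.
def _demo_propagator_one_step():
    prop = Propagator_concat_one_step(latent_dim=2)
    z = torch.randn(8, 2)
    re = torch.rand(8, 1)
    z_next = prop(z, re)                   # (8, 2), i.e. z(t + dt)
    return z_next.shape
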
class Model(nn.Module):
    def __init__(self, encoder, decoder, propagator):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder          # decoder for x(t)
        self.propagator = propagator    # used to time-march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, tau, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Small FCNN propagator advances z(t) to z(t+tau)
        z_tau = self.propagator(z, tau, re)
        # Reconstructions
        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction of x(t+tau)
        return x_hat, x_hat_tau, mean, log_var, z_tau
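
# End-to-end sketch of the VAE + propagator, here paired with Propagator_concat;
# any propagator with a (z, tau, re) forward would fit. input_dim=100 and the
# batch/tau/re values are assumptions for illustration only.
def _demo_model():
    model = Model(
        encoder=Encoder(input_dim=100, latent_dim=2),
        decoder=Decoder(latent_dim=2, output_dim=100),
        propagator=Propagator_concat(latent_dim=2),
    )
    x = torch.randn(8, 1, 100)
    tau = torch.rand(8, 1)
    re = torch.rand(8, 1)
    x_hat, x_hat_tau, mean, log_var, z_tau = model(x, tau, re)
    return x_hat.shape, x_hat_tau.shape    # both (8, 1, 100)
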
class Model_One_Step(nn.Module):
    # Takes only x and Re (no tau argument), since tau is fixed to a single step.
    def __init__(self, encoder, decoder, propagator):
        super(Model_One_Step, self).__init__()
        self.encoder = encoder
        self.decoder = decoder          # decoder for x(t)
        self.propagator = propagator    # used to time-march z(t) to z(t+dt)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Propagator_concat_one_step advances z(t) to z(t+1*dt)
        z_tau = self.propagator(z, re)
        # Reconstructions
        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction of x(t+dt)
        return x_hat, x_hat_tau, mean, log_var, z_tau
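
# One-step model sketch: pairs naturally with Propagator_concat_one_step, so x is
# 2-D here, (batch, input_dim). Dimensions are assumed for illustration.
def _demo_model_one_step():
    model = Model_One_Step(
        encoder=Encoder(input_dim=100, latent_dim=2),
        decoder=Decoder(latent_dim=2, output_dim=100),
        propagator=Propagator_concat_one_step(latent_dim=2),
    )
    x = torch.randn(8, 100)
    re = torch.rand(8, 1)
    x_hat, x_hat_tau, mean, log_var, z_tau = model(x, re)
    return x_hat.shape, x_hat_tau.shape    # both (8, 100)
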
class Model_reproduce(nn.Module):
    def __init__(self, encoder, decoder):
        super(Model_reproduce, self).__init__()
        self.encoder = encoder
        self.decoder = decoder          # decoder for x(t)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        # Reconstruction
        x_hat = self.decoder(z)         # reconstruction of x(t)
        return x_hat, mean, log_var
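
# Plain VAE sketch (no propagator): encode, reparameterize, decode one snapshot.
# Dimensions are assumed as in the earlier demos.
def _demo_model_reproduce():
    model = Model_reproduce(Encoder(input_dim=100, latent_dim=2),
                            Decoder(latent_dim=2, output_dim=100))
    x = torch.randn(8, 1, 100)
    x_hat, mean, log_var = model(x)
    return x_hat.shape                     # (8, 1, 100)
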
# Loss for the reconstruction-only model: MSE reconstruction term plus KL divergence.
def loss_function_reproduce(x, x_hat, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, KLD


# Loss for the propagated model: reconstruction terms at t and t+tau plus KL divergence.
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, reconstruction_loss2, KLD
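
# Sketch of how the three loss terms might be combined in a training step. The KL
# weight `beta` is an assumption; the original file only returns the raw terms.
def _demo_loss(beta=1e-3):
    model = Model(Encoder(100, 2), Decoder(2, 100), Propagator_concat(2))
    x, x_tau = torch.randn(8, 1, 100), torch.randn(8, 1, 100)
    tau, re = torch.rand(8, 1), torch.rand(8, 1)
    x_hat, x_hat_tau, mean, log_var, _ = model(x, tau, re)
    rec1, rec2, kld = loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var)
    return rec1 + rec2 + beta * kld        # scalar suitable for .backward()
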
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params
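
# Quick trainable-parameter count for an assembled model (dimensions assumed as above).
def _demo_count_parameters():
    model = Model(Encoder(100, 2), Decoder(2, 100), Propagator_concat(2))
    return count_parameters(model)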