#!/usr/bin/env python
# coding: utf-8
"""Autoencoder + LSTM reduced-order surrogate for the viscous Burgers' equation.

The module builds snapshot matrices from the analytical Burgers' solution
(`exact_solution`), compresses them with a fully connected autoencoder
(`AE_Encoder` / `AE_Decoder` / `AE_Model`), and advances the latent state in
time with a stacked LSTM (`PytorchLSTM`).  `measure_lstm_prediction_time`
times a walk-forward latent prediction for one test simulation.
"""

import math
import time

import numpy as np
import torch
import torch.nn as nn

# Spatial / temporal discretization shared by the snapshot generators.
num_time_steps = 500
x = np.linspace(0.0, 1.0, num=128)
dx = 1.0 / np.shape(x)[0]
tsteps = np.linspace(0.0, 2.0, num=num_time_steps)
dt = 2.0 / np.shape(tsteps)[0]


class AE_Encoder(nn.Module):
    """Fully connected encoder mapping a state vector to a latent code.

    Parameters
    ----------
    input_dim : int
        Dimension of the input state (number of spatial grid points).
    latent_dim : int
        Dimension of the latent code (default 2).
    feats : list[int]
        Hidden-layer widths, applied in order.
    """

    def __init__(self, input_dim, latent_dim=2, feats=None):
        super().__init__()
        # NOTE: default list is built here (not as a mutable default argument).
        widths = [512, 256, 128, 64, 32] if feats is None else feats
        self.latent_dim = latent_dim
        layers = []
        prev = input_dim
        # GELU between every pair of linear layers; the final projection to
        # the latent code is left linear.
        for w in widths:
            layers.extend([nn.Linear(prev, w), nn.GELU()])
            prev = w
        layers.append(nn.Linear(prev, latent_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent code Z for input batch ``x``."""
        return self._net(x)


class AE_Decoder(nn.Module):
    """Fully connected decoder mapping a latent code back to state space.

    Parameters
    ----------
    latent_dim : int
        Dimension of the latent code.
    output_dim : int
        Dimension of the reconstructed state.
    feats : list[int]
        Hidden-layer widths, applied in order (mirror of the encoder).
    """

    def __init__(self, latent_dim, output_dim, feats=None):
        super().__init__()
        widths = [32, 64, 128, 256, 512] if feats is None else feats
        self.output_dim = output_dim
        layers = []
        prev = latent_dim
        for w in widths:
            layers.extend([nn.Linear(prev, w), nn.GELU()])
            prev = w
        layers.append(nn.Linear(prev, output_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the reconstructed state for latent batch ``x``."""
        return self._net(x)


class AE_Model(nn.Module):
    """Autoencoder: encoder followed by decoder (x -> z -> x_hat)."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction x_hat."""
        z = self.encoder(x)
        return self.decoder(z)


class PytorchLSTM(nn.Module):
    """Two stacked LSTMs + dense head, mimicking a Keras
    ``LSTM(return_sequences=True) -> LSTM() -> Dense`` stack.

    Parameters
    ----------
    input_dim : int
        Features per time step (default 3: two latents + one parameter).
    hidden_dim : int
        LSTM hidden width.
    output_dim : int
        Size of the dense output (default 2: the predicted latents).
    """

    def __init__(self, input_dim=3, hidden_dim=40, output_dim=2):
        super().__init__()
        # First LSTM: full sequence output feeds the second LSTM
        # (Keras return_sequences=True).
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Second LSTM: only its last time step is used
        # (Keras return_sequences=False).
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # Dense projection to the output latents.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """x: [batch_size, time_window, input_dim] -> [batch_size, output_dim]."""
        out1, _ = self.lstm1(x)      # [batch, time_window, hidden_dim]
        out2, _ = self.lstm2(out1)   # [batch, time_window, hidden_dim]
        # PyTorch LSTMs always return the full sequence; slicing the last
        # step reproduces Keras' return_sequences=False behavior.
        last_timestep = out2[:, -1, :]  # [batch, hidden_dim]
        return self.fc(last_timestep)   # [batch, output_dim]


def measure_lstm_prediction_time(
    decoder,
    lstm_model,
    lstm_testing_data,
    sim_num,
    final_time,
    time_window=40,
):
    """
    Predicts up to `final_time` in a walk-forward manner for simulation
    `sim_num`, measures the elapsed time, and returns the final predicted
    latent + the true latent.

    Parameters
    ----------
    decoder : torch.nn.Module
        The trained decoder (maps a latent code back to full state).
    lstm_model : torch.nn.Module
        Trained PyTorch LSTM model. Set to eval() inside.
    lstm_testing_data : np.ndarray
        Shape (num_test_snapshots, num_time_steps, 3). The last dimension
        typically holds (2 latents + 1 param) or similar.
    sim_num : int
        Which simulation index to use (e.g., 0 for the first).
    final_time : int
        The final timestep index to predict up to (>= time_window and
        strictly less than the number of available time steps, since the
        true latent at `final_time` is read from the data).
    time_window : int
        Size of the rolling window (default=40).

    Returns
    -------
    float
        Elapsed time (seconds) for the predictions from t=time_window up
        to t=final_time (includes one decoder reconstruction).
    np.ndarray
        The final predicted latent at time=final_time (shape (2,)).
    np.ndarray
        The true latent at time=final_time (shape (2,)).

    Raises
    ------
    ValueError
        If `final_time` is out of range for the supplied data, or smaller
        than `time_window`.
    """
    available_steps = lstm_testing_data.shape[1]
    # `final_time` is used as an index into the data below, so it must be
    # strictly less than the number of stored time steps (bug fix: the old
    # `>` check let final_time == available_steps through and then the
    # final_true lookup raised IndexError).
    if final_time >= available_steps:
        raise ValueError(
            f"final_time={final_time} exceeds available time steps={available_steps}."
        )
    if final_time < time_window:
        raise ValueError(
            f"final_time={final_time} is less than time_window={time_window}, no prediction needed."
        )

    # Seed the rolling window with the first `time_window` true snapshots.
    input_seq = np.zeros((1, time_window, 3), dtype=np.float32)
    input_seq[0, :, :] = lstm_testing_data[sim_num, 0:time_window, :]

    lstm_model.eval()  # inference mode
    final_pred = None  # stores the latest predicted latent

    start_time = time.time()
    with torch.no_grad():
        # Walk forward: each iteration predicts the latent at step t from
        # the preceding `time_window` steps, then shifts the window so the
        # prediction feeds the next step (the parameter channel, index 2,
        # keeps its previous value).
        for t in range(time_window, final_time):
            inp_tensor = torch.from_numpy(input_seq).float()  # [1, time_window, 3]
            pred = lstm_model(inp_tensor)                     # [1, 2]
            pred_np = pred.numpy()[0, :]                      # (2,)

            # Shift the rolling window left by one and append the prediction.
            input_seq[0, 0:time_window - 1, :] = input_seq[0, 1:time_window, :].copy()
            input_seq[0, time_window - 1, 0:2] = pred_np

            final_pred = pred_np

        # Reconstruct the full state once from the last predicted latent
        # (hoisted out of the loop: the per-step reconstructions were
        # unused; the single decode is still part of the timed region).
        if final_pred is not None:
            x_hat_tau_pred = decoder(torch.tensor(final_pred, dtype=torch.float32))
    end_time = time.time()
    elapsed = end_time - start_time

    # The *true* latent at `final_time`, read from the test data.
    final_true = lstm_testing_data[sim_num, final_time, 0:2]  # (2,)

    return elapsed, final_pred, final_true


def collect_snapshots(Rnum):
    """Return the snapshot matrix (grid points x time steps) for one Reynolds
    number, columns being the analytical solution at each time in `tsteps`."""
    snapshot_matrix = np.zeros(shape=(np.shape(x)[0], np.shape(tsteps)[0]))
    for t in range(np.shape(tsteps)[0]):
        snapshot_matrix[:, t] = exact_solution(Rnum, tsteps[t])[:]
    return snapshot_matrix


def collect_multiparam_snapshots_train():
    """Stack transposed snapshot matrices for the training Reynolds numbers.

    Returns (all_snapshots, scaled_rnum_vals) where rows of `all_snapshots`
    are individual time snapshots and `scaled_rnum_vals` is Rnum/1000.
    """
    rnum_vals = np.arange(900, 2900, 100)
    all_snapshots = np.concatenate(
        [np.transpose(collect_snapshots(rv)) for rv in rnum_vals], axis=0
    )
    return all_snapshots, rnum_vals / 1000


def collect_multiparam_snapshots_test():
    """Stack transposed snapshot matrices for the held-out test Reynolds
    numbers (interleaved between the training values).

    Returns (all_snapshots, scaled_rnum_vals) with Rnum scaled by 1/1000.
    """
    rnum_vals = np.arange(1050, 2850, 200)
    all_snapshots = np.concatenate(
        [np.transpose(collect_snapshots(rv)) for rv in rnum_vals], axis=0
    )
    return all_snapshots, rnum_vals / 1000


def exact_solution(Rnum, t):
    """Analytical solution of the viscous Burgers' equation on [0, 1] at
    time `t` for Reynolds number `Rnum`; returns an array over 128 grid
    points."""
    x = np.linspace(0.0, 1.0, num=128)
    t0 = np.exp(Rnum / 8.0)
    return (x / (t + 1)) / (
        1.0 + np.sqrt((t + 1) / t0) * np.exp(Rnum * (x * x) / (4.0 * t + 4))
    )