Khalid Rafiq committed on
Commit
ab72d17
·
1 Parent(s): ad06c72

Add all required modules and requirements.txt

Browse files
LSTM_model.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import time
11
+ import math
12
+ import torch
13
+
14
# Space-time grid shared by the snapshot helpers in this module.
num_time_steps = 500

# 128 spatial nodes on [0, 1].
x = np.linspace(0.0, 1.0, num=128)
# Domain length divided by the number of nodes (not the linspace spacing).
dx = 1.0 / x.shape[0]

# `num_time_steps` time samples on [0, 2].
tsteps = np.linspace(0.0, 2.0, num=num_time_steps)
# Time span divided by the number of samples.
dt = 2.0 / tsteps.shape[0]
19
+
20
class AE_Encoder(nn.Module):
    """Fully connected encoder: ``Linear -> GELU`` stages, then a linear projection.

    Parameters
    ----------
    input_dim : int
        Size of the last dimension of the input.
    latent_dim : int
        Size of the latent vector produced by the final linear layer (default 2).
    feats : sequence of int
        Hidden-layer widths, in order. Any number of widths is accepted; the
        default gives five hidden layers (512, 256, 128, 64, 32).
    """

    def __init__(self, input_dim, latent_dim=2, feats=(512, 256, 128, 64, 32)):
        super(AE_Encoder, self).__init__()
        self.latent_dim = latent_dim

        widths = [input_dim, *feats]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.GELU())
        # No activation after the projection into the latent space.
        layers.append(nn.Linear(widths[-1], latent_dim))

        # For five widths the module order is Linear/GELU x5 + Linear, so the
        # ``_net.<index>`` state-dict keys match a hand-written Sequential.
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent representation of ``x``."""
        return self._net(x)
41
+
42
+
43
class AE_Decoder(nn.Module):
    """Fully connected decoder: ``Linear -> GELU`` stages, then a linear output layer.

    Parameters
    ----------
    latent_dim : int
        Size of the last dimension of the latent input.
    output_dim : int
        Size of the reconstructed output produced by the final linear layer.
    feats : sequence of int
        Hidden-layer widths, in order. Any number of widths is accepted; the
        default gives five hidden layers (32, 64, 128, 256, 512).
    """

    def __init__(self, latent_dim, output_dim, feats=(32, 64, 128, 256, 512)):
        super(AE_Decoder, self).__init__()
        self.output_dim = output_dim

        widths = [latent_dim, *feats]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.GELU())
        # No activation after the final output layer.
        layers.append(nn.Linear(widths[-1], output_dim))

        # For five widths the module order is Linear/GELU x5 + Linear, so the
        # ``_net.<index>`` state-dict keys match a hand-written Sequential.
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Map the latent input ``x`` to an output of size ``output_dim``."""
        return self._net(x)
64
+
65
+
66
class AE_Model(nn.Module):
    """Autoencoder that chains an encoder module and a decoder module."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        # Decoder that reconstructs x(t) from the latent code.
        self.decoder = decoder

    def forward(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction
78
+
79
+
80
class PytorchLSTM(nn.Module):
    """Two stacked LSTM layers followed by a linear head on the last time step."""

    def __init__(self, input_dim=3, hidden_dim=40, output_dim=2):
        super().__init__()
        # First recurrent layer: its full output sequence feeds the second one
        # (the equivalent of return_sequences=True).
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Second recurrent layer: only its last time step is used in forward()
        # (the equivalent of return_sequences=False).
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # Dense head mapping the hidden state to the outputs.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """
        Predict from a window of inputs.

        x shape: [batch_size, time_window, input_dim]
        returns: [batch_size, output_dim]
        """
        # [batch_size, time_window, hidden_dim]; final states are not needed.
        first_seq, _ = self.lstm1(x)
        second_seq, _ = self.lstm2(first_seq)

        # PyTorch always returns the full sequence, so slice the last step:
        # [batch_size, hidden_dim].
        final_step = second_seq[:, -1, :]
        return self.fc(final_step)
109
+
110
def measure_lstm_prediction_time(
    decoder,
    lstm_model,
    lstm_testing_data,
    sim_num,
    final_time,
    time_window=40
):
    """
    Roll the LSTM forward in a walk-forward manner for simulation `sim_num`,
    time the rollout (including one decoder call on the last prediction), and
    return the last predicted latent together with the reference latent.

    Parameters
    ----------
    decoder : torch.nn.Module
        Trained decoder. It is applied once to the last predicted latent; its
        output is discarded and only contributes to the measured time.
    lstm_model : torch.nn.Module
        Trained PyTorch LSTM model. It is switched to eval mode here.
    lstm_testing_data : np.ndarray
        Shape (num_test_snapshots, num_time_steps, 3).
        The last dimension typically holds (2 latents + 1 param) or similar.
    sim_num : int
        Which simulation index to use (e.g., 0 for the first).
    final_time : int
        Time index of the reference latent. Must satisfy
        ``time_window < final_time < num_time_steps``. One prediction is made
        per value in ``range(time_window, final_time)``.
    time_window : int
        Size of the rolling window (default=40).

    Returns
    -------
    float
        Elapsed time (seconds) for the rollout plus the final decoder call.
    np.ndarray
        The last predicted latent (shape (2,)).
    np.ndarray
        ``lstm_testing_data[sim_num, final_time, 0:2]`` (shape (2,)).

    Raises
    ------
    ValueError
        If ``final_time`` is not a valid time index, or if it does not exceed
        ``time_window`` (no prediction would be made).
    """
    num_time_steps = lstm_testing_data.shape[1]
    # `final_time` indexes the ground truth below, so it has to be a valid
    # index; rejecting it here avoids an IndexError after the whole rollout.
    if final_time >= num_time_steps:
        raise ValueError(
            f"final_time={final_time} exceeds the last available time index={num_time_steps - 1}."
        )
    # With final_time == time_window the loop below would not run and there
    # would be no prediction to decode or return.
    if final_time <= time_window:
        raise ValueError(
            f"final_time={final_time} must be greater than time_window={time_window}, no prediction needed."
        )

    # Rolling window initialised with the first `time_window` true steps.
    input_seq = np.zeros((1, time_window, 3), dtype=np.float32)
    input_seq[0, :, :] = lstm_testing_data[sim_num, 0:time_window, :]

    lstm_model.eval()  # inference mode

    final_pred = None
    # perf_counter is monotonic and meant for measuring short durations.
    start_time = time.perf_counter()

    with torch.no_grad():
        for _ in range(time_window, final_time):
            inp_tensor = torch.from_numpy(input_seq).float()  # [1, time_window, 3]
            pred = lstm_model(inp_tensor)                      # [1, 2]
            pred_np = pred.numpy()[0, :]                       # (2,)

            # Drop the oldest step and append the prediction. Only the two
            # latent columns are overwritten; the third column of the newest
            # row keeps its previous value.
            shifted = input_seq[0, 1:time_window, :].copy()
            input_seq[0, 0:time_window - 1, :] = shifted
            input_seq[0, time_window - 1, 0:2] = pred_np

            final_pred = pred_np

        # Decoding is part of the timed work; the decoded field is not returned.
        decoder(torch.tensor(final_pred, dtype=torch.float32))

    elapsed = time.perf_counter() - start_time

    # NOTE(review): if the model predicts the step right after its window,
    # the last prediction corresponds to index final_time - 1 while the
    # reference is read at index final_time - confirm the intended alignment.
    final_true = lstm_testing_data[sim_num, final_time, 0:2]  # shape (2,)

    return elapsed, final_pred, final_true
197
+
198
+
199
def collect_snapshots(Rnum):
    """Evaluate `exact_solution` for `Rnum` at every entry of the module-level `tsteps`.

    Returns an array of shape (len(x), len(tsteps)) whose column `k` holds the
    solution at time `tsteps[k]`.
    """
    n_space = x.shape[0]
    n_time = tsteps.shape[0]
    snapshot_matrix = np.zeros(shape=(n_space, n_time))

    for col, t_val in enumerate(tsteps):
        snapshot_matrix[:, col] = exact_solution(Rnum, t_val)[:]

    return snapshot_matrix
207
+
208
def collect_multiparam_snapshots_train():
    """Stack the snapshots of all training Rnum values.

    Rnum takes the values 900, 1000, ..., 2800. For each value the matrix from
    `collect_snapshots` is transposed (one row per time step) and the blocks
    are stacked along axis 0 in increasing Rnum order.

    Returns
    -------
    all_snapshots : np.ndarray
        The stacked snapshot rows.
    np.ndarray
        The Rnum values divided by 1000.
    """
    rnum_vals = np.arange(900, 2900, 100)

    # Build every block first and concatenate once, instead of re-copying a
    # growing array on each iteration.
    blocks = [np.transpose(collect_snapshots(rnum_val)) for rnum_val in rnum_vals]
    all_snapshots = np.concatenate(blocks, axis=0)

    return all_snapshots, rnum_vals/1000
223
+
224
def collect_multiparam_snapshots_test():
    """Stack the snapshots of all test Rnum values.

    Rnum takes the values 1050, 1250, ..., 2650. For each value the matrix from
    `collect_snapshots` is transposed (one row per time step) and the blocks
    are stacked along axis 0 in increasing Rnum order.

    Returns
    -------
    all_snapshots : np.ndarray
        The stacked snapshot rows.
    np.ndarray
        The Rnum values divided by 1000.
    """
    rnum_vals = np.arange(1050, 2850, 200)

    # Build every block first and concatenate once, instead of re-copying a
    # growing array on each iteration.
    blocks = [np.transpose(collect_snapshots(rnum_val)) for rnum_val in rnum_vals]
    all_snapshots = np.concatenate(blocks, axis=0)

    # The stray, unreachable `return elapsed, final_pred, final_true` that
    # followed this return (names undefined in this function) was removed.
    return all_snapshots, rnum_vals/1000
243
+
244
def exact_solution(Rnum,t):
    """Evaluate the closed-form profile for parameter `Rnum` at time `t`.

    The profile is sampled on 128 equally spaced points of [0, 1] and returned
    as a 1-D array of length 128.
    """
    grid = np.linspace(0.0, 1.0, num=128)
    t0 = np.exp(Rnum / 8.0)
    shifted_t = t + 1

    numerator = grid / shifted_t
    growth = np.exp(Rnum * (grid * grid) / (4.0 * t + 4))
    denominator = 1.0 + np.sqrt(shifted_t / t0) * growth
    return numerator / denominator
248
+
249
+
__pycache__/LSTM_model.cpython-310.pyc ADDED
Binary file (6.23 kB). View file
 
__pycache__/config_adv_dif.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
__pycache__/config_burgers.cpython-310.pyc ADDED
Binary file (1.19 kB). View file
 
__pycache__/data_adv_dif.cpython-310.pyc ADDED
Binary file (7.21 kB). View file
 
__pycache__/data_burgers.cpython-310.pyc ADDED
Binary file (7.42 kB). View file
 
__pycache__/model_adv_dif.cpython-310.pyc ADDED
Binary file (5.36 kB). View file
 
__pycache__/model_io_adv_dif.cpython-310.pyc ADDED
Binary file (886 Bytes). View file
 
__pycache__/model_io_burgers.cpython-310.pyc ADDED
Binary file (883 Bytes). View file
 
__pycache__/model_v2.cpython-310.pyc ADDED
Binary file (9.64 kB). View file
 
config_adv_dif.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import json
3
+
4
@dataclass
class Config:
    """Run configuration; every field has a project default."""

    # default values. DO NOT TOUCH
    name: str = "FlexiPropagator_2D"
    latent_dim: int = 3
    batch_size: int = 64
    lr: float = 3e-4
    num_epochs: int = 25
    num_time_steps: int = 500

    # Weighting coefficients (their exact use is defined by the training code).
    gamma: float = 3.25
    beta: float = 1e-3

    # Validation / plotting cadence.
    val_every: float = 0.25
    plot_train_every: float = 0.01

    # Directory name used for checkpoints.
    save_dir: str = "checkpoints"
21
+
22
def load_config(path):
    """Build a `Config` from the JSON object stored in the file at `path`.

    The JSON keys are passed to `Config` as keyword arguments, so fields that
    are absent keep their defaults and an unknown key raises `TypeError`.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        overrides = json.load(f)
    return Config(**overrides)
config_burgers.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import json
3
+
4
@dataclass
class Config:
    """Run configuration; every field has a project default."""

    # default values. DO NOT TOUCH
    name: str = "FlexiPropagator"
    latent_dim: int = 2
    input_dim: int = 128
    batch_size: int = 128
    lr: float = 3e-4
    num_epochs: int = 200
    n_samples_train: int = 800_000

    num_time_steps: int = 500

    # Fractions bounding the tau range (their exact use is defined by the caller).
    tau_left_fraction: float = 0.35
    tau_right_fraction: float = 0.85

    # Weighting coefficients (their exact use is defined by the training code).
    gamma: float = 3.25
    beta: float = 1e-4

    # Validation / plotting cadence.
    val_every: float = 0.25
    plot_train_every: float = 0.01

    # Directory name used for checkpoints.
    save_dir: str = "checkpoints"
27
+
28
def load_config(path):
    """Build a `Config` from the JSON object stored in the file at `path`.

    The JSON keys are passed to `Config` as keyword arguments, so fields that
    are absent keep their defaults and an unknown key raises `TypeError`.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        overrides = json.load(f)
    return Config(**overrides)
data_adv_dif.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from dataclasses import dataclass, asdict
7
+ import json
8
+
9
+
10
+ # Rnum = 1000
11
num_time_steps = 500


def get_dt(num_time_steps):
    """Return the time-step size ``2.0 / num_time_steps``."""
    step = 2.0 / num_time_steps
    return step


# Module-wide step size derived from the default number of time steps.
dt = get_dt(num_time_steps)
17
+
18
+
19
def exact_solution(alpha, t, L=2.0, Nx=128, Ny=128, c=1.0):
    """Evaluate a Gaussian profile centred at ``(c*t, 0)`` on a 2-D grid.

    The grid spans ``[-L, L]`` in both directions with ``Nx`` by ``Ny`` points
    and the result has shape ``(Ny, Nx)``. The diffusivity is ``nu = 1/alpha``.
    For ``t <= 0`` an all-zero field is returned.
    """
    nu = 1.0 / alpha
    grid_x, grid_y = np.meshgrid(np.linspace(-L, L, Nx), np.linspace(-L, L, Ny))
    if t <= 0:
        return np.zeros_like(grid_x)

    # Squared distance from the advected centre (c*t, 0).
    dist_sq = (grid_x - c * t)**2 + grid_y**2
    spread = 4.0 * nu * t
    peak = 1.0 / (4.0 * np.pi * nu * t)
    return peak * np.exp(-dist_sq / spread)
33
+
34
+
35
class AdvectionDiffussionDataset:
    """Container for paired fields (`X`, `X_tau`) and their `t`, `tau`, `alpha` values.

    Every attribute defaults to ``None`` so that an empty dataset can be
    created and filled with `append`.
    """

    def __init__(self,
                 X: np.ndarray = None,
                 X_tau: np.ndarray = None,
                 t_values: np.ndarray = None,
                 tau_values: np.ndarray = None,
                 alpha_values: np.ndarray = None):
        self.X, self.X_tau = X, X_tau
        self.t_values, self.tau_values = t_values, tau_values
        self.alpha_values = alpha_values

    def append(self, other):
        """Extend this dataset in place with the arrays of `other`.

        An attribute that is still ``None`` simply takes over the array of
        `other`; otherwise the two arrays are concatenated along axis 0.
        """
        for field in ('X', 'X_tau', 't_values', 'tau_values', 'alpha_values'):
            mine = getattr(self, field)
            theirs = getattr(other, field)
            merged = theirs if mine is None else np.concatenate([mine, theirs])
            setattr(self, field, merged)
54
+
55
@dataclass
class IntervalSplit:
    """Held-out sub-intervals of a parameter range, each a ``(low, high)`` pair."""

    # Interval held out inside the range.
    interpolation: tuple
    # Interval held out at the lower end of the range.
    extrapolation_left: tuple
    # Interval held out at the upper end of the range.
    extrapolation_right: tuple
60
+
61
def prepare_adv_diff_dataset(alpha_range=(0.01, 10), tau_range=(150, 400), dt=dt, n_samples=500):
    """Draw `n_samples` random (state, shifted state) pairs from `exact_solution`.

    For every sample, `alpha` is drawn uniformly from `alpha_range`, `t`
    uniformly from (0.01, 2.0) and the integer `tau` with
    ``np.random.randint(*tau_range)``. The pair consists of the solution at
    `t` and at ``t + tau * dt``.

    Returns an `AdvectionDiffussionDataset` holding the stacked arrays.
    """
    TRANGE = (0.01, 2.0)
    columns = {'X': [], 'X_tau': [], 't': [], 'tau': [], 'alpha': []}

    while len(columns['X']) < n_samples:
        # sample alpha uniformly
        alpha = np.random.uniform(*alpha_range)
        t = np.random.uniform(*TRANGE)
        x_t = exact_solution(alpha, t)
        tau = np.random.randint(*tau_range)
        x_tau = exact_solution(alpha, t+(tau*dt))

        for key, value in zip(columns, (x_t, x_tau, t, tau, alpha)):
            columns[key].append(value)

    arrays = [np.array(values) for values in columns.values()]
    return AdvectionDiffussionDataset(*arrays)
89
+
90
+
91
def train_test_split_range(interval, interpolation_span=0.1, extrapolation_left_span=0.1, extrapolation_right_span=0.1):
    """
    Carve held-out sub-intervals out of `interval` and return what is left for training.

    The spans are fractions of the total interval length:
    1. Left extrapolation fold: the first `extrapolation_left_span` of the interval.
    2. Right extrapolation fold: the last `extrapolation_right_span` of the interval.
    3. Interpolation fold: a window of length `interpolation_span` whose start
       is drawn uniformly at random between the two extrapolation folds.

    Overall layout:
        extrapolation_left | train | interpolation | train | extrapolation_right

    Returns
    -------
    IntervalSplit
        The three held-out intervals.
    list of tuple
        The two training intervals, left of and right of the interpolation window.
    """
    low, high = interval
    width = (high-low)

    left_cut = low + extrapolation_left_span * width
    right_cut = high - extrapolation_right_span * width
    window = interpolation_span * width

    # Random placement of the interpolation window inside the central part.
    window_start = np.random.uniform(left_cut, right_cut - window)
    interpolation = (window_start, window_start + window)

    held_out = IntervalSplit(interpolation, (low, left_cut), (right_cut, high))
    train_ranges = [(left_cut, interpolation[0]), (interpolation[1], right_cut)]
    return held_out, train_ranges
121
+
122
def get_train_ranges(interval_split):
    """Return the two training intervals implied by the held-out intervals.

    The first runs from the end of the left extrapolation interval to the
    start of the interpolation interval, the second from the end of the
    interpolation interval to the start of the right extrapolation interval.
    """
    left_end = interval_split.extrapolation_left[1]
    window_start = interval_split.interpolation[0]
    window_end = interval_split.interpolation[1]
    right_start = interval_split.extrapolation_right[0]
    return [(left_end, window_start), (window_end, right_start)]
127
+
128
def get_train_val_test_folds(alpha_range, tau_range,
                             alpha_interpolation_span=0.10,
                             alpha_extrapolation_left_span=0.10,
                             alpha_extrapolation_right_span=0.10,
                             tau_interpolation_span=0.10,
                             tau_extrapolation_left_span=0.10,
                             tau_extrapolation_right_span=0.10,
                             n_samples_train=500,
                             n_samples_val=200):
    """
    Build the training and validation datasets from split alpha and tau ranges.

    Both ranges are split with `train_test_split_range`. The training set is
    sampled on every combination of the 2 alpha training intervals with the
    2 tau training intervals (`n_samples_train` samples each). The validation
    set is sampled on every combination of the 3 held-out alpha intervals with
    the 3 held-out tau intervals (`n_samples_val` samples each).

    Returns:
        dataset_train       : AdvectionDiffussionDataset
        dataset_val         : AdvectionDiffussionDataset
        alpha_interval_split: IntervalSplit
        tau_interval_split  : IntervalSplit
    """
    # 1) Split both parameter ranges into held-out and training intervals.
    alpha_interval_split, alpha_train_ranges = train_test_split_range(
        alpha_range,
        alpha_interpolation_span,
        alpha_extrapolation_left_span,
        alpha_extrapolation_right_span
    )
    tau_interval_split, tau_train_ranges = train_test_split_range(
        tau_range,
        tau_interpolation_span,
        tau_extrapolation_left_span,
        tau_extrapolation_right_span
    )

    # 2) TRAIN: Cartesian product of the training intervals (alpha outer, tau inner).
    dataset_train = AdvectionDiffussionDataset()
    train_cells = [(a_rng, t_rng) for a_rng in alpha_train_ranges for t_rng in tau_train_ranges]
    for alpha_subrange, tau_subrange in train_cells:
        dataset_train.append(
            prepare_adv_diff_dataset(
                alpha_range=alpha_subrange,
                tau_range=tau_subrange,
                n_samples=n_samples_train
            )
        )

    # 3) VAL: Cartesian product of the held-out intervals, in the order
    #    left extrapolation, interpolation, right extrapolation.
    held_out_alpha = [getattr(alpha_interval_split, part)
                      for part in ('extrapolation_left', 'interpolation', 'extrapolation_right')]
    held_out_tau = [getattr(tau_interval_split, part)
                    for part in ('extrapolation_left', 'interpolation', 'extrapolation_right')]

    dataset_val = AdvectionDiffussionDataset()
    val_cells = [(a_rng, t_rng) for a_rng in held_out_alpha for t_rng in held_out_tau]
    for a_val_range, t_val_range in val_cells:
        dataset_val.append(
            prepare_adv_diff_dataset(
                alpha_range=a_val_range,
                tau_range=t_val_range,
                n_samples=n_samples_val
            )
        )

    return dataset_train, dataset_val, alpha_interval_split, tau_interval_split
217
+
218
+
219
def plot_sample(dataset, i):
    """
    Show sample `i` of `dataset`: the initial field next to its time-shifted field.
    """
    fields_now, fields_later = dataset.X, dataset.X_tau
    times, shifts, alphas = dataset.t_values, dataset.tau_values, dataset.alpha_values

    print("Shape of X:", fields_now.shape)

    fig, (ax_now, ax_later) = plt.subplots(1, 2, figsize=(12, 5))

    image_now = ax_now.imshow(fields_now[i], extent=[0, 1, 0, 1], origin='lower', cmap='hot')
    ax_now.set_title(f'Initial State (t: {times[i]})')
    plt.colorbar(image_now, ax=ax_now)

    image_later = ax_later.imshow(fields_later[i], extent=[0, 1, 0, 1], origin='lower', cmap='hot')
    # The shifted time uses the module-level step size `dt`.
    ax_later.set_title(f'Shifted State (t + tau): {times[i]+shifts[i]*dt}')
    plt.colorbar(image_later, ax=ax_later)

    fig.suptitle(f'Tau: {shifts[i]}, Alpha: {alphas[i]:.4f}')
    plt.show()
242
+
243
+
244
def load_from_path(path):
    """Load the datasets and interval splits stored in directory `path`.

    Reads `dataset_train.pkl`, `dataset_val.pkl`, `alpha_interval_split.json`
    and `tau_interval_split.json`.

    Returns
    -------
    dataset_train, dataset_val
        The unpickled dataset objects.
    alpha_interval_split, tau_interval_split : IntervalSplit
        The interval splits, with their intervals restored to tuples.
    """
    def _read_pickle(file_name):
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(os.path.join(path, file_name), 'rb') as f:
            return pickle.load(f)

    def _read_split(file_name):
        with open(os.path.join(path, file_name), 'r', encoding='utf-8') as f:
            raw = json.load(f)
        # JSON has no tuple type: intervals saved as tuples come back as
        # lists, so convert them back to match the IntervalSplit fields.
        fields = {key: tuple(value) if isinstance(value, list) else value
                  for key, value in raw.items()}
        return IntervalSplit(**fields)

    dataset_train = _read_pickle('dataset_train.pkl')
    dataset_val = _read_pickle('dataset_val.pkl')
    alpha_interval_split = _read_split('alpha_interval_split.json')
    tau_interval_split = _read_split('tau_interval_split.json')

    return dataset_train, dataset_val, alpha_interval_split, tau_interval_split
data_burgers.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import pickle
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ from dataclasses import dataclass, asdict
8
+ import json
9
+
10
+
11
+ # Rnum = 1000
12
num_time_steps = 500
# # x = np.linspace(0.0,1.0,num=128)
# dx = 1.0/np.shape(x)[0]
# TSTEPS = np.linspace(0.0,2.0,num=num_time_steps)
# dt = 2.0/np.shape(TSTEPS)[0]


def get_dt(num_time_steps):
    """Return the time-step size ``2.0 / num_time_steps``."""
    step = 2.0 / num_time_steps
    return step


# Module-wide step size derived from the default number of time steps.
dt = get_dt(num_time_steps)
22
+
23
+
24
def exact_solution(Rnum,t):
    """Evaluate the closed-form profile for parameter `Rnum` at time `t`.

    The profile is sampled on 128 equally spaced points of [0, 1] and returned
    as a 1-D array of length 128.
    """
    grid = np.linspace(0.0, 1.0, num=128)
    t0 = np.exp(Rnum / 8.0)
    shifted_t = t + 1

    numerator = grid / shifted_t
    growth = np.exp(Rnum * (grid * grid) / (4.0 * t + 4))
    denominator = 1.0 + np.sqrt(shifted_t / t0) * growth
    return numerator / denominator
28
+
29
+
30
class ReDataset:
    """Container for paired states (`X`, `X_tau`) and their `t`, `tau`, `Re` values.

    Every attribute defaults to ``None`` so that an empty dataset can be
    created and filled with `append`.
    """

    def __init__(self,
                 X: np.ndarray = None,
                 X_tau: np.ndarray = None,
                 t_values: np.ndarray = None,
                 tau_values: np.ndarray = None,
                 Re_values: np.ndarray = None):
        self.X, self.X_tau = X, X_tau
        self.t_values, self.tau_values = t_values, tau_values
        self.Re_values = Re_values

    def append(self, other):
        """Extend this dataset in place with the arrays of `other`.

        An attribute that is still ``None`` simply takes over the array of
        `other`; otherwise the two arrays are concatenated along axis 0.
        """
        for field in ('X', 'X_tau', 't_values', 'tau_values', 'Re_values'):
            mine = getattr(self, field)
            theirs = getattr(other, field)
            merged = theirs if mine is None else np.concatenate([mine, theirs])
            setattr(self, field, merged)
49
+
50
@dataclass
class IntervalSplit:
    """Held-out sub-intervals of a parameter range, each a ``(low, high)`` pair."""

    # Interval held out inside the range.
    interpolation: tuple
    # Interval held out at the lower end of the range.
    extrapolation_left: tuple
    # Interval held out at the upper end of the range.
    extrapolation_right: tuple
55
+
56
def get_time_shifts(snapshots, tau_range=(100, 500), n_samples=100):
    """Draw `n_samples` random pairs ``(snapshots[i], snapshots[i + tau])``.

    `tau` is drawn with ``np.random.randint(*tau_range)`` and the start index
    `i` with ``np.random.randint(0, len(snapshots) - tau)``.

    Returns the arrays `X`, `X_tau` and `tau_values`.
    """
    draws = []  # (start index, tau) per sample
    while len(draws) < n_samples:
        tau = np.random.randint(*tau_range)
        start = np.random.randint(0, len(snapshots)-tau)
        draws.append((start, tau))

    X = np.array([snapshots[start] for start, _ in draws])
    X_tau = np.array([snapshots[start + tau] for start, tau in draws])
    tau_values = np.array([tau for _, tau in draws])
    return X, X_tau, tau_values
70
+
71
def prepare_Re_dataset(Re_range=(100, 2000), tau_range=(500, 1900), dt=dt, n_samples=5000):
    """Draw `n_samples` random (state, shifted state) pairs from `exact_solution`.

    For every sample, `Re` is drawn log-uniformly from `Re_range` and rounded
    to an integer, `t` is drawn uniformly from (0, 2) and the integer `tau`
    with ``np.random.randint(*tau_range)``. The pair consists of the solution
    at `t` and at ``t + tau * dt``.

    Returns a `ReDataset` holding the stacked arrays.
    """
    TRANGE = (0,2)
    columns = {'X': [], 'X_tau': [], 't': [], 'tau': [], 'Re': []}

    while len(columns['X']) < n_samples:
        # sample Re log uniformly
        log_re = np.random.uniform(np.log(Re_range[0]), np.log(Re_range[1]))
        Re = np.exp(log_re).round().astype(int)
        t = np.random.uniform(*TRANGE)
        x_t = exact_solution(Re, t)
        tau = np.random.randint(*tau_range)
        x_tau = exact_solution(Re, t+(tau*dt))

        for key, value in zip(columns, (x_t, x_tau, t, tau, Re)):
            columns[key].append(value)

    arrays = [np.array(values) for values in columns.values()]
    return ReDataset(*arrays)
102
+
103
+
104
def train_test_split_range(interval, interpolation_span=0.1, extrapolation_left_span=0.1, extrapolation_right_span=0.1):
    """
    Carve held-out sub-intervals out of `interval` and return what is left for training.

    The spans are fractions of the total interval length:
    1. Left extrapolation fold: the first `extrapolation_left_span` of the interval.
    2. Right extrapolation fold: the last `extrapolation_right_span` of the interval.
    3. Interpolation fold: a window of length `interpolation_span` whose start
       is drawn uniformly at random between the two extrapolation folds.

    Overall layout:
        extrapolation_left | train | interpolation | train | extrapolation_right

    Returns
    -------
    IntervalSplit
        The three held-out intervals.
    list of tuple
        The two training intervals, left of and right of the interpolation window.
    """
    low, high = interval
    width = (high-low)

    left_cut = low + extrapolation_left_span * width
    right_cut = high - extrapolation_right_span * width
    window = interpolation_span * width

    # Random placement of the interpolation window inside the central part.
    window_start = np.random.uniform(left_cut, right_cut - window)
    interpolation = (window_start, window_start + window)

    held_out = IntervalSplit(interpolation, (low, left_cut), (right_cut, high))
    train_ranges = [(left_cut, interpolation[0]), (interpolation[1], right_cut)]
    return held_out, train_ranges
134
+
135
def get_train_ranges(interval_split):
    """Return the two training intervals implied by the held-out intervals.

    The first runs from the end of the left extrapolation interval to the
    start of the interpolation interval, the second from the end of the
    interpolation interval to the start of the right extrapolation interval.
    """
    left_end = interval_split.extrapolation_left[1]
    window_start = interval_split.interpolation[0]
    window_end = interval_split.interpolation[1]
    right_start = interval_split.extrapolation_right[0]
    return [(left_end, window_start), (window_end, right_start)]
140
+
141
+ # def get_dataset_from_ranges(train_ranges):
142
+ # dataset = ReDataset()
143
+ # for re_train_range, tau_train_range in zip(Re_train_ranges, tau_train_ranges):
144
+ # train_dataset = prepare_Re_dataset(Re_range=re_train_range, tau_range=tau_train_range, n_samples=n_samples_train)
145
+ # dataset.append(train_dataset)
146
+
147
+ # return dataset
148
+
149
+
150
def get_train_val_test_folds(Re_range, tau_range,
                             re_interpolation_span=0.10,
                             re_extrapolation_left_span=0.1,
                             re_extrapolation_right_span=0.10,
                             tau_interpolation_span=0.10,
                             tau_extrapolation_left_span=0.1,
                             tau_extrapolation_right_span=0.10,
                             n_samples_train=1000000,
                             val_split=0.2):
    """Sample a dataset on the training intervals and split it into train and validation parts.

    Both ranges are split with `train_test_split_range`. `n_samples_train`
    samples are drawn for each (Re interval, tau interval) pair, the pooled
    samples are shuffled, and the last `val_split` fraction becomes the
    validation set.

    Returns
    -------
    dataset_train, dataset_val : ReDataset
    Re_interval_split, tau_interval_split : IntervalSplit
    """
    Re_interval_split, Re_train_ranges = train_test_split_range(
        Re_range, re_interpolation_span, re_extrapolation_left_span, re_extrapolation_right_span)
    tau_interval_split, tau_train_ranges = train_test_split_range(
        tau_range, tau_interpolation_span, tau_extrapolation_left_span, tau_extrapolation_right_span)

    # NOTE(review): zip pairs the i-th Re interval with the i-th tau interval
    # only (2 combinations, not the 4 of a Cartesian product) - confirm that
    # this is intended.
    pooled = ReDataset()
    for re_train_range, tau_train_range in zip(Re_train_ranges, tau_train_ranges):
        pooled.append(
            prepare_Re_dataset(Re_range=re_train_range, tau_range=tau_train_range, n_samples=n_samples_train)
        )

    order = np.arange(len(pooled.X))
    np.random.shuffle(order)
    n_train = int(len(order)*(1-val_split))

    def _subset(indices):
        # Select the same rows from every array of the pooled dataset.
        return ReDataset(pooled.X[indices], pooled.X_tau[indices], pooled.t_values[indices],
                         pooled.tau_values[indices], pooled.Re_values[indices])

    dataset_train = _subset(order[:n_train])
    dataset_val = _subset(order[n_train:])
    return dataset_train, dataset_val, Re_interval_split, tau_interval_split
178
+
179
def plot_sample(dataset, i):
    """Plot sample `i` of `dataset`: the initial state and its mapped state."""
    states, mapped_states = dataset.X, dataset.X_tau
    shifts, reynolds = dataset.tau_values, dataset.Re_values

    plt.plot(states[i], label="Initial State")
    plt.plot(mapped_states[i], label="Mapped State")
    plt.title(f'Tau: {shifts[i]}, Re: {reynolds[i]}')
    plt.legend()
    plt.show()
189
+
190
def save_to_path(path, dataset_train, dataset_val, Re_interval_split, tau_interval_split):
    """Store the datasets and interval splits in directory `path`.

    The directory is created if needed. The datasets are pickled to
    `dataset_train.pkl` and `dataset_val.pkl`; the interval splits (dataclass
    instances) are written as JSON to `Re_interval_split.json` and
    `tau_interval_split.json`.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(path, exist_ok=True)

    pickled = (
        ('dataset_train.pkl', dataset_train),
        ('dataset_val.pkl', dataset_val),
    )
    for file_name, dataset in pickled:
        with open(os.path.join(path, file_name), 'wb') as f:
            pickle.dump(dataset, f)

    as_json = (
        ('Re_interval_split.json', Re_interval_split),
        ('tau_interval_split.json', tau_interval_split),
    )
    for file_name, split in as_json:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(os.path.join(path, file_name), 'w', encoding='utf-8') as f:
            json.dump(asdict(split), f)
208
+
209
def load_from_path(path):
    """Load the datasets and interval splits stored in directory `path`.

    Reads `dataset_train.pkl`, `dataset_val.pkl`, `Re_interval_split.json`
    and `tau_interval_split.json`.

    Returns
    -------
    dataset_train, dataset_val
        The unpickled dataset objects.
    Re_interval_split, tau_interval_split : IntervalSplit
        The interval splits, with their intervals restored to tuples.
    """
    def _read_pickle(file_name):
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(os.path.join(path, file_name), 'rb') as f:
            return pickle.load(f)

    def _read_split(file_name):
        with open(os.path.join(path, file_name), 'r', encoding='utf-8') as f:
            raw = json.load(f)
        # JSON has no tuple type: intervals saved as tuples come back as
        # lists, so convert them back to match the IntervalSplit fields.
        fields = {key: tuple(value) if isinstance(value, list) else value
                  for key, value in raw.items()}
        return IntervalSplit(**fields)

    dataset_train = _read_pickle('dataset_train.pkl')
    dataset_val = _read_pickle('dataset_val.pkl')
    Re_interval_split = _read_split('Re_interval_split.json')
    tau_interval_split = _read_split('tau_interval_split.json')

    return dataset_train, dataset_val, Re_interval_split, tau_interval_split
227
+
228
+
229
def main():
    """Script entry point: load the dataset stored in the `data` directory."""
    #Re_range = (100, 3000)
    #num_time_steps = 500
    #tau_range = (175, 425)
    #dataset_train, dataset_val, Re_interval_split, tau_interval_split = get_train_val_test_folds(Re_range, tau_range)
    #save_to_path('data', dataset_train, dataset_val, Re_interval_split, tau_interval_split)

    load_from_path('data')


if __name__ == '__main__':
    main()
243
+
model_adv_dif.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import math
12
+ import torch
13
+
14
+ import os
15
+ import torch
16
+ import torch.nn as nn
17
+ import numpy as np
18
+ import pickle
19
+ from dataclasses import dataclass, asdict
20
+ import json
21
+ from torch.utils.data import DataLoader
22
+
23
+
24
+ # Normalization Layer for Conv2D
25
class Norm(nn.Module):
    """Group normalization applied directly to channel-first feature maps."""

    def __init__(self, num_channels, num_groups=4):
        """
        :param num_channels: number of channels of the incoming tensor
        :param num_groups: number of groups the channels are split into
        """
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)

    def forward(self, x):
        """Return ``x`` normalized by the wrapped ``nn.GroupNorm``."""
        normalized = self.norm(x)
        return normalized
32
+
33
+ # Encoder using Conv2D
34
class Encoder(nn.Module):
    """
    Convolutional VAE encoder returning the mean and log-variance of the latent code.

    Five kernel-2 / stride-2 convolutions each halve the spatial size while the
    channel count goes 1 -> 32 -> 64 -> 128 -> 256 -> 512. The two linear heads
    expect ``512 * 4 * 4`` flattened features, i.e. a 4x4 final feature map,
    which corresponds to e.g. a (batch_size, 1, 128, 128) input.
    """

    def __init__(self, latent_dim=3):
        """
        :param latent_dim: size of the latent vector produced by each head
        """
        super().__init__()

        channels = (1, 32, 64, 128, 256, 512)
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # One down-sampling stage: conv -> GELU -> group norm.
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=2, stride=2, padding=0),
                nn.GELU(),
                Norm(c_out),
            ])
        self.conv_layers = nn.Sequential(*stages)

        self.flatten = nn.Flatten()
        flat_features = 512 * 4 * 4
        self.fc_mean = nn.Linear(flat_features, latent_dim)
        self.fc_log_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """
        :param x: input batch fed to the convolution stack
        :return: (mean, log_var), each of width ``latent_dim``
        """
        features = self.flatten(self.conv_layers(x))
        return self.fc_mean(features), self.fc_log_var(features)
65
+
66
+
67
+
68
class Decoder(nn.Module):
    """
    Convolutional decoder mapping a latent vector back to a single-channel image.

    A linear layer expands the latent vector to ``512 * 4 * 4`` values that are
    reshaped to (batch_size, 512, 4, 4). Five bilinear x2 up-sampling stages
    followed by 1x1 convolutions then bring it to (batch_size, 1, 128, 128).
    The trailing ReLU keeps the output non-negative.
    """

    def __init__(self, latent_dim=3):
        """
        :param latent_dim: size of the latent vector accepted by ``forward``
        """
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        blocks = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64), (64, 32)):
            # One up-sampling stage: x2 bilinear -> 1x1 conv -> GELU -> group norm.
            blocks += [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(c_in, c_out, kernel_size=1),
                nn.GELU(),
                Norm(c_out),
            ]
        # Output stage: no normalization, ReLU instead of GELU.
        blocks += [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU(),
        ]
        self.deconv_layers = nn.Sequential(*blocks)

    def forward(self, z):
        """
        :param z: latent batch whose last dimension is ``latent_dim``
        :return: decoded batch produced by the up-sampling stack
        """
        feature_map = self.fc(z).view(-1, 512, 4, 4)
        return self.deconv_layers(feature_map)
107
+
108
+
109
class Propagator_concat(nn.Module):
    """
    Takes in (z(t), tau, alpha) and outputs z(t+tau)
    """

    def __init__(self, latent_dim, feats=[16, 32, 64, 32, 16]):
        """
        Build the propagator MLP.

        :param latent_dim: width of the latent vector z
        :param feats: hidden widths; the first five entries are used
        """
        super().__init__()

        # The input width is latent_dim + 2: one extra column each for tau and alpha.
        widths = [latent_dim + 2] + [feats[i] for i in range(5)]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.GELU())
        layers.append(nn.Linear(widths[-1], latent_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, z, tau, alpha):
        """
        Concatenate z with tau and alpha along dim 1 and run the MLP.

        :param z: latent batch; a size-1 dimension at index 1 is squeezed away
        :param tau: time-offset column concatenated after z
        :param alpha: parameter column concatenated after tau
        :return: (z_tau, z_) - the propagated latent and the concatenated network input
        """
        latent = z.squeeze(1)
        z_ = torch.cat((latent, tau, alpha), dim=1)
        return self._net(z_), z_
144
+
145
+
146
+
147
+
148
class Model(nn.Module):
    """VAE with a latent propagator: encodes x(t), advances the latent, decodes both."""

    def __init__(self, encoder, decoder, propagator):
        """
        :param encoder: callable returning (mean, log_var) for an input batch
        :param decoder: callable mapping a latent batch back to data space
        :param propagator: callable (z, tau, alpha) -> (z_tau, z_) advancing the latent
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator

    def reparameterization(self, mean, var):
        """
        Sample ``mean + var * eps`` with eps drawn from a standard normal.

        ``forward`` passes ``exp(0.5 * log_var)`` as ``var``, so despite its
        name the argument acts as a standard deviation.
        """
        noise = torch.randn_like(var)
        return mean + var * noise

    def forward(self, x, tau, alpha):
        """
        :param x: input batch x(t)
        :param tau: time offset handed to the propagator
        :param alpha: extra parameter handed to the propagator
        :return: (x_hat, x_hat_tau, mean, log_var, z_tau, z_)
        """
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)

        # March the sampled latent forward by tau.
        z_tau, z_ = self.propagator(z, tau, alpha)

        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction decoded from z(t+tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau, z_
172
+
173
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """
    Compute the three VAE loss terms.

    :param x: original input x(t)
    :param x_tau: ground-truth future input x(t+tau)
    :param x_hat: reconstruction of x(t)
    :param x_hat_tau: prediction of x(t+tau)
    :param mean: mean of the latent distribution
    :param log_var: log variance of the latent distribution
    :return: reconstruction_loss1, reconstruction_loss2, KLD
    """
    mse = nn.MSELoss()
    reconstruction_loss1 = mse(x, x_hat)          # reconstruction error at t
    reconstruction_loss2 = mse(x_tau, x_hat_tau)  # prediction error at t+tau

    # KL divergence to a standard normal: summed over dim 1, then averaged.
    kl_per_sample = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1)
    KLD = torch.mean(kl_per_sample)

    return reconstruction_loss1, reconstruction_loss2, KLD
model_io_adv_dif.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dataclasses import dataclass, asdict
3
+ from data_adv_dif import IntervalSplit
4
+ from config_adv_dif import Config
5
+
6
+
7
def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    """
    Write a checkpoint containing the model weights and their metadata.

    :param path: destination passed to ``torch.save``
    :param model: module whose ``state_dict()`` is stored
    :param tau_interval_split: dataclass instance, stored as a plain dict
    :param alpha_interval_split: dataclass instance, stored as a plain dict
    :param config: dataclass instance, stored as a plain dict
    """
    # Dataclasses are flattened to dicts so the file holds only plain containers.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'alpha_interval_split': asdict(alpha_interval_split),
        'tau_interval_split': asdict(tau_interval_split),
        'config': asdict(config),
    }
    torch.save(checkpoint, path)
14
+
15
+
16
def load_model(path, model):
    """
    Restore a checkpoint written by ``save_model`` into ``model``.

    :param path: checkpoint file to read
    :param model: module that receives the stored ``model_state_dict``
    :return: (model, alpha_interval_split, tau_interval_split, config)
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can also be
    # opened on a CPU-only host; load_state_dict then copies the values into
    # the model's own parameters, whatever device they live on.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # The metadata was stored as plain dicts; rebuild the dataclass instances.
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config
model_io_burgers.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dataclasses import dataclass, asdict
3
+ from data_burgers import IntervalSplit
4
+ from config_burgers import Config
5
+
6
+
7
def save_model(path, model, tau_interval_split, re_interval_split, config):
    """
    Write a checkpoint containing the model weights and their metadata.

    :param path: destination passed to ``torch.save``
    :param model: module whose ``state_dict()`` is stored
    :param tau_interval_split: dataclass instance, stored as a plain dict
    :param re_interval_split: dataclass instance, stored as a plain dict
    :param config: dataclass instance, stored as a plain dict
    """
    # Dataclasses are flattened to dicts so the file holds only plain containers.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        're_interval_split': asdict(re_interval_split),
        'tau_interval_split': asdict(tau_interval_split),
        'config': asdict(config),
    }
    torch.save(checkpoint, path)
14
+
15
+
16
def load_model(path, model):
    """
    Restore a checkpoint written by ``save_model`` into ``model``.

    :param path: checkpoint file to read
    :param model: module that receives the stored ``model_state_dict``
    :return: (model, re_interval_split, tau_interval_split, config)
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can also be
    # opened on a CPU-only host; load_state_dict then copies the values into
    # the model's own parameters, whatever device they live on.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # The metadata was stored as plain dicts; rebuild the dataclass instances.
    re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, re_interval_split, tau_interval_split, config
model_v2.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import math
12
+ import torch
13
+
14
+
15
+ # In[2]:
16
+
17
+
18
def positionalencoding1d(d_model, length):
    """
    Build a sinusoidal positional-encoding table.

    Even columns hold sin(position * freq) and odd columns hold
    cos(position * freq), with freq = 10000 ** (-2k / d_model) for pair k.

    :param d_model: dimension of the model (must be even)
    :param length: length of positions
    :return: length*d_model position matrix
    :raises ValueError: if ``d_model`` is odd
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))

    positions = torch.arange(0, length).unsqueeze(1).float()
    even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
    div_term = torch.exp(even_dims * -(math.log(10000.0) / d_model))

    # (length, 1) * (d_model / 2,) broadcasts to one angle per position/frequency pair.
    angles = positions * div_term

    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe
35
+
36
+
37
+ # In[3]:
38
+
39
+
40
class Norm(nn.Module):
    """Group normalization for tensors whose channels sit in the last of three dims."""

    def __init__(self, num_channels, num_groups=4):
        """
        :param num_channels: size of the last dimension of the incoming tensor
        :param num_groups: number of groups the channels are split into
        """
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        """Swap dims 1 and 2 so channels come second, normalize, then swap back."""
        channels_first = x.permute(0, 2, 1)
        normalized = self.norm(channels_first)
        return normalized.permute(0, 2, 1)
47
+
48
+
49
+ # In[4]:
50
+
51
class Norm_new(nn.Module):
    """Group normalization accepting either 2-D input or channels-last 3-D input."""

    def __init__(self, num_channels, num_groups=4):
        """
        :param num_channels: size of the channel (last) dimension of the input
        :param num_groups: number of groups the channels are split into
        """
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        """
        Normalize ``x`` over its channel groups.

        2-D input (batch_size, num_channels) gets a temporary trailing axis;
        any other input has dims 1 and 2 swapped around the normalization.
        """
        if x.dim() != 2:
            return self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)

        # (batch_size, num_channels) -> (batch_size, num_channels, 1) and back.
        expanded = x.unsqueeze(-1)
        return self.norm(expanded).squeeze(-1)
66
+
67
+
68
class Encoder(nn.Module):
    """
    Fully connected VAE encoder.

    Five Linear -> GELU -> Norm_new stages are followed by a linear layer of
    width ``2 * latent_dim`` whose output is split into mean and log-variance.
    """

    def __init__(self, input_dim, latent_dim=2, feats=[512, 256, 128, 64, 32]):
        """
        :param input_dim: width of the input vectors
        :param latent_dim: width of each of the two outputs
        :param feats: hidden widths; the first five entries are used
        """
        super().__init__()
        self.latent_dim = latent_dim

        widths = [input_dim] + [feats[i] for i in range(5)]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.GELU(), Norm_new(fan_out)]
        # The last layer emits mean and log-variance side by side.
        layers.append(nn.Linear(widths[-1], 2 * latent_dim))
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input batch with ``input_dim`` features in the last dimension
        :return: (mean, log_var), each ``latent_dim`` wide
        """
        Z = self._net(x)
        mean, log_var = torch.split(Z, self.latent_dim, dim=-1)
        return mean, log_var
95
+
96
+
97
+ # In[5]:
98
+
99
+
100
class Decoder(nn.Module):
    """
    Fully connected decoder mapping a latent vector to an ``output_dim`` vector.

    Five Linear -> GELU -> Norm_new stages are followed by a linear output
    layer and a Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, latent_dim, output_dim, feats=[32, 64, 128, 256, 512]):
        """
        :param latent_dim: width of the latent input
        :param output_dim: width of the decoded output
        :param feats: hidden widths; the first five entries are used
        """
        super().__init__()
        self.output_dim = output_dim

        widths = [latent_dim] + [feats[i] for i in range(5)]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.GELU(), Norm_new(fan_out)]
        layers += [nn.Linear(widths[-1], output_dim), nn.Tanh()]
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the latent batch ``x``."""
        return self._net(x)
127
+
128
+
129
+ # In[6]:
130
+
131
+
132
class Propagator(nn.Module):
    """
    Takes in (z(t), tau) and outputs z(t+tau).

    z is projected to ``encoding_dim`` features, the sinusoidal encoding of
    tau is added to that projection, and a small MLP maps the sum back to
    ``latent_dim``.
    """

    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64):
        """
        Input : (z(t), tau)
        Output: z(t+tau)

        :param latent_dim: width of the latent vector
        :param feats: the two hidden widths of the output MLP
        :param max_tau: number of rows in the encoding table
        :param encoding_dim: width of the projection and of each encoding row
        """
        super().__init__()
        self.max_tau = max_tau

        # Fixed lookup table of shape (max_tau, encoding_dim); a buffer, not a parameter.
        self.register_buffer('encodings', positionalencoding1d(encoding_dim, max_tau))

        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )

        hidden_a, hidden_b = feats[0], feats[1]
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, hidden_a),
            nn.ReLU(),
            Norm(hidden_a),
            nn.Linear(hidden_a, hidden_b),
            nn.ReLU(),
            Norm(hidden_b),
            nn.Linear(hidden_b, latent_dim),
        )

    def forward(self, z, tau):
        """
        :param z: latent batch z(t)
        :param tau: offsets; cast to integers and used as row indices into the table
        :return: propagated latent z(t+tau)
        """
        projected = self.projector(z)
        tau_encoding = self.encodings[tau.long()]
        # The encoding is added to the projection, not concatenated with it.
        combined = projected + tau_encoding
        return self._net(combined)
171
+
172
+
173
+ # Doing this for the embedding for Re
174
class Propagator_encoding(nn.Module):
    """
    Takes in (z(t), tau, re) and outputs z(t+tau).

    z is projected to ``encoding_dim`` features, the sinusoidal encodings of
    tau and re are both added to that projection, and a small MLP maps the
    sum back to ``latent_dim``.
    """

    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64, max_re = 5000):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)

        :param latent_dim: width of the latent vector
        :param feats: the two hidden widths of the output MLP
        :param max_tau: number of rows in the tau encoding table
        :param encoding_dim: width of the projection and of each encoding row
        :param max_re: number of rows in the re encoding table
        """
        super().__init__()
        self.max_tau = max_tau
        self.max_re = max_re

        # Fixed lookup tables of shape (max_tau, encoding_dim) and (max_re, encoding_dim).
        self.register_buffer('tau_encodings', positionalencoding1d(encoding_dim, max_tau))
        self.register_buffer('re_encodings', positionalencoding1d(encoding_dim, max_re))

        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )

        hidden_a, hidden_b = feats[0], feats[1]
        self._net = nn.Sequential(
            nn.Linear(encoding_dim, hidden_a),
            nn.ReLU(),
            Norm(hidden_a),
            nn.Linear(hidden_a, hidden_b),
            nn.ReLU(),
            Norm(hidden_b),
            nn.Linear(hidden_b, latent_dim),
        )

    def forward(self, z, tau, re):
        """
        :param z: latent batch z(t)
        :param tau: offsets; cast to integers and used as row indices into the tau table
        :param re: values cast to integers and used as row indices into the re table
        :return: propagated latent z(t+tau)
        """
        projected = self.projector(z)
        tau_enc = self.tau_encodings[tau.long()]
        re_enc = self.re_encodings[re.long()]
        # Both encodings are added to the projection, not concatenated with it.
        combined = projected + tau_enc + re_enc
        return self._net(combined)
217
+
218
+
219
+
220
class Propagator_concat(nn.Module):
    """Takes in (z(t), tau, re), concatenated along dim 1, and outputs z(t+tau)."""

    def __init__(self, latent_dim, feats = [16, 32]):
        """
        Input : (z(t), tau, re)
        Output: z(t+tau)

        :param latent_dim: width of the latent vector
        :param feats: the two hidden widths of the MLP
        """
        super().__init__()

        hidden_a, hidden_b = feats[0], feats[1]
        # Input width is latent_dim + 2: one extra column each for tau and re.
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 2, hidden_a),
            nn.ReLU(),
            nn.Linear(hidden_a, hidden_b),
            nn.ReLU(),
            nn.Linear(hidden_b, latent_dim),
        )

    def forward(self, z, tau, re):
        """
        :param z: latent batch; a size-1 dimension at index 1 is squeezed away
        :param tau: column concatenated after z
        :param re: column concatenated after tau
        :return: propagated latent with a new axis inserted at index 1
        """
        flat_latent = z.squeeze(1)
        z_ = torch.cat((flat_latent, tau, re), dim = 1)
        propagated = self._net(z_)
        # Restore the axis removed by squeeze(1).
        return propagated[:, None, :]
246
+
247
+
248
+
249
class Propagator_concat_one_step(nn.Module):
    """Takes in (z(t), re), concatenated along dim 1, and outputs the next latent."""

    def __init__(self, latent_dim, feats = [16, 32]):
        """
        Input : (z(t), re)
        Output: z(t+1*dt)

        :param latent_dim: width of the latent vector
        :param feats: the two hidden widths of the MLP
        """
        super().__init__()

        hidden_a, hidden_b = feats[0], feats[1]
        # Input width is latent_dim + 1: the single extra column carries re.
        self._net = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_a),
            nn.ReLU(),
            nn.Linear(hidden_a, hidden_b),
            nn.Tanh(),
            nn.Linear(hidden_b, latent_dim),
        )

    def forward(self, z, re):
        """
        :param z: latent batch, used as given (no squeeze)
        :param re: column concatenated after z along dim 1
        :return: propagated latent, same rank as the concatenated input
        """
        z_ = torch.cat((z, re), dim = 1)
        return self._net(z_)
276
+
277
+
278
+
279
class Model(nn.Module):
    """VAE with a latent propagator: encodes x(t), advances the latent, decodes both."""

    def __init__(self, encoder, decoder, propagator):
        """
        :param encoder: callable returning (mean, log_var) for an input batch
        :param decoder: callable mapping a latent batch back to data space
        :param propagator: callable (z, tau, re) -> z_tau advancing the latent
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator

    def reparameterization(self, mean, var):
        """
        Sample ``mean + var * eps`` with eps drawn from a standard normal.

        ``forward`` passes ``exp(0.5 * log_var)`` as ``var``, so despite its
        name the argument acts as a standard deviation.
        """
        noise = torch.randn_like(var)
        return mean + var * noise

    def forward(self, x, tau, re):
        """
        :param x: input batch x(t)
        :param tau: time offset handed to the propagator
        :param re: extra parameter handed to the propagator
        :return: (x_hat, x_hat_tau, mean, log_var, z_tau)
        """
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)

        # March the sampled latent forward by tau.
        z_tau = self.propagator(z, tau, re)

        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction decoded from z(t+tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau
303
+
304
+
305
class Model_One_Step(nn.Module):
    """VAE with a one-step propagator: takes only x and re, with no tau argument."""

    def __init__(self, encoder, decoder, propagator):
        """
        :param encoder: callable returning (mean, log_var) for an input batch
        :param decoder: callable mapping a latent batch back to data space
        :param propagator: callable (z, re) -> next latent
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator

    def reparameterization(self, mean, var):
        """
        Sample ``mean + var * eps`` with eps drawn from a standard normal.

        ``forward`` passes ``exp(0.5 * log_var)`` as ``var``, so despite its
        name the argument acts as a standard deviation.
        """
        noise = torch.randn_like(var)
        return mean + var * noise

    def forward(self, x, re):
        """
        :param x: input batch x(t)
        :param re: extra parameter handed to the propagator
        :return: (x_hat, x_hat_tau, mean, log_var, z_tau)
        """
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)

        # Single propagation step; the propagator receives only (z, re).
        z_tau = self.propagator(z, re)

        x_hat = self.decoder(z)          # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)  # prediction decoded from the advanced latent

        return x_hat, x_hat_tau, mean, log_var, z_tau
329
+
330
+
331
+
332
class Model_reproduce(nn.Module):
    """Plain VAE without a propagator: encode, sample, decode."""

    def __init__(self, encoder, decoder):
        """
        :param encoder: callable returning (mean, log_var) for an input batch
        :param decoder: callable mapping a latent batch back to data space
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterization(self, mean, var):
        """
        Sample ``mean + var * eps`` with eps drawn from a standard normal.

        ``forward`` passes ``exp(0.5 * log_var)`` as ``var``, so despite its
        name the argument acts as a standard deviation.
        """
        noise = torch.randn_like(var)
        return mean + var * noise

    def forward(self, x):
        """
        :param x: input batch x(t)
        :return: (x_hat, mean, log_var)
        """
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)

        x_hat = self.decoder(z)  # reconstruction of x(t)
        return x_hat, mean, log_var
351
+
352
+
353
+ # Define loss function
354
def loss_function_reproduce(x, x_hat, mean, log_var):
    """
    Compute the reconstruction and KL terms of the plain VAE loss.

    :param x: original input
    :param x_hat: reconstruction of x
    :param mean: mean of the latent distribution (latent axis last)
    :param log_var: log variance of the latent distribution (latent axis last)
    :return: reconstruction_loss1, KLD
    """
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)

    # Sum over the last (latent) axis instead of a hard-coded dim=2: identical
    # for the 3-D latents used before, and it no longer raises on 2-D
    # (batch, latent) statistics.
    kl_per_sample = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=-1)
    KLD = torch.mean(kl_per_sample)
    return reconstruction_loss1, KLD
359
+
360
+
361
+ # Define loss function
362
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """
    Compute the three VAE loss terms.

    :param x: original input x(t)
    :param x_tau: ground-truth future input x(t+tau)
    :param x_hat: reconstruction of x(t)
    :param x_hat_tau: prediction of x(t+tau)
    :param mean: mean of the latent distribution (latent axis last)
    :param log_var: log variance of the latent distribution (latent axis last)
    :return: reconstruction_loss1, reconstruction_loss2, KLD
    """
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)          # reconstruction error at t
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)  # prediction error at t+tau

    # Sum over the last (latent) axis instead of a hard-coded dim=2: identical
    # for the 3-D latents used before, and it no longer raises on 2-D
    # (batch, latent) statistics.
    kl_per_sample = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=-1)
    KLD = torch.mean(kl_per_sample)
    return reconstruction_loss1, reconstruction_loss2, KLD
368
+
369
+
370
def count_parameters(model):
    """
    Count the trainable scalar parameters of ``model``.

    :param model: module whose ``parameters()`` are inspected
    :return: total number of elements over parameters with ``requires_grad`` set
    """
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()
    return total_params
373
+
374
+
375
+
376
+
377
+
378
+
379
+
380
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ gradio
4
+ matplotlib
5
+ dataclasses
6
+ json5
7
+