Commit ab72d17
Khalid Rafiq committed
Parent(s): ad06c72
Add all required modules and requirements.txt
- LSTM_model.py +249 -0
- __pycache__/LSTM_model.cpython-310.pyc +0 -0
- __pycache__/config_adv_dif.cpython-310.pyc +0 -0
- __pycache__/config_burgers.cpython-310.pyc +0 -0
- __pycache__/data_adv_dif.cpython-310.pyc +0 -0
- __pycache__/data_burgers.cpython-310.pyc +0 -0
- __pycache__/model_adv_dif.cpython-310.pyc +0 -0
- __pycache__/model_io_adv_dif.cpython-310.pyc +0 -0
- __pycache__/model_io_burgers.cpython-310.pyc +0 -0
- __pycache__/model_v2.cpython-310.pyc +0 -0
- config_adv_dif.py +25 -0
- config_burgers.py +31 -0
- data_adv_dif.py +261 -0
- data_burgers.py +243 -0
- model_adv_dif.py +190 -0
- model_io_adv_dif.py +22 -0
- model_io_burgers.py +22 -0
- model_v2.py +380 -0
- requirements.txt +7 -0
LSTM_model.py
ADDED
@@ -0,0 +1,249 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:

import numpy as np
import torch
import torch.nn as nn
import time
import math

num_time_steps = 500
x = np.linspace(0.0, 1.0, num=128)
dx = 1.0 / np.shape(x)[0]
tsteps = np.linspace(0.0, 2.0, num=num_time_steps)
dt = 2.0 / np.shape(tsteps)[0]


class AE_Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim=2, feats=[512, 256, 128, 64, 32]):
        super(AE_Encoder, self).__init__()
        self.latent_dim = latent_dim
        self._net = nn.Sequential(
            nn.Linear(input_dim, feats[0]),
            nn.GELU(),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            nn.Linear(feats[4], latent_dim)
        )

    def forward(self, x):
        Z = self._net(x)
        return Z


class AE_Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, feats=[32, 64, 128, 256, 512]):
        super(AE_Decoder, self).__init__()
        self.output_dim = output_dim
        self._net = nn.Sequential(
            nn.Linear(latent_dim, feats[0]),
            nn.GELU(),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            nn.Linear(feats[4], output_dim),
        )

    def forward(self, x):
        y = self._net(x)
        return y


class AE_Model(nn.Module):
    def __init__(self, encoder, decoder):
        super(AE_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)

    def forward(self, x):
        z = self.encoder(x)
        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)
        return x_hat


class PytorchLSTM(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=40, output_dim=2):
        super().__init__()
        # First LSTM: simulates Keras return_sequences=True
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Second LSTM: simulates Keras return_sequences=False
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # Dense layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """
        x shape: [batch_size, time_window, input_dim]
        """
        # LSTM1 (return_sequences=True)
        out1, (h1, c1) = self.lstm1(x)
        # out1 shape: [batch_size, time_window, hidden_dim]

        # LSTM2: a PyTorch LSTM always returns the full sequence, so we slice
        # out the last time step to mimic return_sequences=False.
        out2, (h2, c2) = self.lstm2(out1)
        # out2 shape: [batch_size, time_window, hidden_dim]
        last_timestep = out2[:, -1, :]  # shape: [batch_size, hidden_dim]

        # Dense -> output_dim outputs
        output = self.fc(last_timestep)  # shape: [batch_size, output_dim]
        return output


def measure_lstm_prediction_time(
    decoder,
    lstm_model,
    lstm_testing_data,
    sim_num,
    final_time,
    time_window=40
):
    """
    Predicts up to `final_time` in a walk-forward manner for simulation `sim_num`,
    measures the elapsed time, and returns the final predicted latent + the true latent.

    Parameters
    ----------
    decoder : torch.nn.Module
        The trained decoder.
    lstm_model : torch.nn.Module
        Trained PyTorch LSTM model. We set lstm_model.eval() inside.
    lstm_testing_data : np.ndarray
        Shape (num_test_snapshots, num_time_steps, 3).
        The last dimension typically holds (2 latents + 1 param) or similar.
    sim_num : int
        Which simulation index to use (e.g., 0 for the first).
    final_time : int
        The final timestep index to predict up to (>= time_window).
        For example, if time_window=10 and final_time=20, we predict from t=10..19.
    time_window : int
        Size of the rolling window (default=40).

    Returns
    -------
    float
        Elapsed time (seconds) for performing the predictions from t=time_window up to t=final_time.
    np.ndarray
        The final predicted latent at time=final_time (shape (2,)).
    np.ndarray
        The true latent at time=final_time (shape (2,)).
    """
    # Basic shape info
    num_time_steps = lstm_testing_data.shape[1]
    if final_time > num_time_steps:
        raise ValueError(
            f"final_time={final_time} exceeds available time steps={num_time_steps}."
        )
    if final_time < time_window:
        raise ValueError(
            f"final_time={final_time} is less than time_window={time_window}, no prediction needed."
        )

    # Initialize the rolling window with the first `time_window` steps
    input_seq = np.zeros((1, time_window, 3), dtype=np.float32)
    input_seq[0, :, :] = lstm_testing_data[sim_num, 0:time_window, :]

    lstm_model.eval()  # inference mode

    final_pred = None  # store the final predicted latent
    start_time = time.time()

    with torch.no_grad():
        # Predict from t=time_window to t=final_time-1, so that at the end of
        # the loop we have generated a prediction for index final_time.
        for t in range(time_window, final_time):
            inp_tensor = torch.from_numpy(input_seq).float()  # shape [1, time_window, 3]
            pred = lstm_model(inp_tensor)                     # shape [1, 2]
            pred_np = pred.numpy()[0, :]                      # shape (2,)

            # Shift the rolling window
            temp = input_seq[0, 1:time_window, :].copy()
            input_seq[0, 0:time_window - 1, :] = temp
            input_seq[0, time_window - 1, 0:2] = pred_np

            # Keep track of the last prediction
            final_pred = pred_np

        # Decode the final latent (the decode is included in the timing)
        x_hat_tau_pred = decoder(torch.tensor(final_pred, dtype=torch.float32))

    end_time = time.time()
    elapsed = end_time - start_time

    # final_pred is the LSTM's predicted latent for step `final_time`.
    # The *true* latent at that time is:
    final_true = lstm_testing_data[sim_num, final_time, 0:2]  # shape (2,)

    return elapsed, final_pred, final_true


def collect_snapshots(Rnum):
    snapshot_matrix = np.zeros(shape=(np.shape(x)[0], np.shape(tsteps)[0]))

    trange = np.arange(np.shape(tsteps)[0])
    for t in trange:
        snapshot_matrix[:, t] = exact_solution(Rnum, tsteps[t])[:]

    return snapshot_matrix


def collect_multiparam_snapshots_train():
    rnum_vals = np.arange(900, 2900, 100)

    rsnap = 0
    for rnum_val in rnum_vals:
        snapshots_temp = np.transpose(collect_snapshots(rnum_val))

        if rsnap == 0:
            all_snapshots = snapshots_temp
        else:
            all_snapshots = np.concatenate((all_snapshots, snapshots_temp), axis=0)

        rsnap = rsnap + 1
    return all_snapshots, rnum_vals / 1000


def collect_multiparam_snapshots_test():
    rnum_vals = np.arange(1050, 2850, 200)

    rsnap = 0
    for rnum_val in rnum_vals:
        snapshots_temp = np.transpose(collect_snapshots(rnum_val))

        if rsnap == 0:
            all_snapshots = snapshots_temp
        else:
            all_snapshots = np.concatenate((all_snapshots, snapshots_temp), axis=0)

        rsnap = rsnap + 1
    return all_snapshots, rnum_vals / 1000


def exact_solution(Rnum, t):
    x = np.linspace(0.0, 1.0, num=128)
    t0 = np.exp(Rnum / 8.0)
    return (x / (t + 1)) / (1.0 + np.sqrt((t + 1) / t0) * np.exp(Rnum * (x * x) / (4.0 * t + 4)))
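Usage sketch (not part of the commit): a minimal walk-forward timing run with untrained components and random data, demonstrating only the shapes and call signature of measure_lstm_prediction_time. The synthetic lstm_testing_data array is an assumption.

import numpy as np
import torch
from LSTM_model import AE_Decoder, PytorchLSTM, measure_lstm_prediction_time

lstm_model = PytorchLSTM(input_dim=3, hidden_dim=40, output_dim=2)  # untrained, shape demo only
decoder = AE_Decoder(latent_dim=2, output_dim=128)                  # maps a 2-d latent back to the 128-point grid

# One fake test simulation: 500 steps of (2 latents + 1 parameter).
lstm_testing_data = np.random.rand(1, 500, 3).astype(np.float32)

elapsed, final_pred, final_true = measure_lstm_prediction_time(
    decoder, lstm_model, lstm_testing_data,
    sim_num=0, final_time=100, time_window=40,
)
print(f"walk-forward to t=100 took {elapsed:.4f}s; predicted latent {final_pred}")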
__pycache__/LSTM_model.cpython-310.pyc
ADDED
Binary file (6.23 kB)

__pycache__/config_adv_dif.cpython-310.pyc
ADDED
Binary file (1.03 kB)

__pycache__/config_burgers.cpython-310.pyc
ADDED
Binary file (1.19 kB)

__pycache__/data_adv_dif.cpython-310.pyc
ADDED
Binary file (7.21 kB)

__pycache__/data_burgers.cpython-310.pyc
ADDED
Binary file (7.42 kB)

__pycache__/model_adv_dif.cpython-310.pyc
ADDED
Binary file (5.36 kB)

__pycache__/model_io_adv_dif.cpython-310.pyc
ADDED
Binary file (886 Bytes)

__pycache__/model_io_burgers.cpython-310.pyc
ADDED
Binary file (883 Bytes)

__pycache__/model_v2.cpython-310.pyc
ADDED
Binary file (9.64 kB)
config_adv_dif.py
ADDED
@@ -0,0 +1,25 @@
from dataclasses import dataclass
import json


@dataclass
class Config:
    # default values. DO NOT TOUCH
    name: str = 'FlexiPropagator_2D'
    latent_dim: int = 3
    batch_size: int = 64
    lr: float = 3e-4
    num_epochs: int = 25
    num_time_steps: int = 500

    gamma: float = 3.25
    beta: float = 1e-3

    val_every: float = 0.25
    plot_train_every: float = 0.01

    save_dir: str = 'checkpoints'


def load_config(path):
    with open(path, 'r') as f:
        config = json.load(f)
    return Config(**config)
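Usage sketch (not part of the commit): load_config expects a JSON object whose keys match the Config fields, so a Config round-trips through dataclasses.asdict and json. The file name config.json is an assumption.

import json
from dataclasses import asdict
from config_adv_dif import Config, load_config

cfg = Config(num_epochs=50)            # override one default
with open('config.json', 'w') as f:
    json.dump(asdict(cfg), f)

cfg2 = load_config('config.json')      # internally Config(**json.load(...))
assert cfg2 == cfg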
config_burgers.py
ADDED
@@ -0,0 +1,31 @@
from dataclasses import dataclass
import json


@dataclass
class Config:
    # default values. DO NOT TOUCH
    name: str = 'FlexiPropagator'
    latent_dim: int = 2
    input_dim: int = 128
    batch_size: int = 128
    lr: float = 3e-4
    num_epochs: int = 200
    n_samples_train: int = 800_000

    num_time_steps: int = 500

    tau_left_fraction: float = 0.35
    tau_right_fraction: float = 0.85

    gamma: float = 3.25
    beta: float = 1e-4

    val_every: float = 0.25
    plot_train_every: float = 0.01

    save_dir: str = 'checkpoints'


def load_config(path):
    with open(path, 'r') as f:
        config = json.load(f)
    return Config(**config)
data_adv_dif.py
ADDED
@@ -0,0 +1,261 @@
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from dataclasses import dataclass, asdict
import json


# Rnum = 1000
num_time_steps = 500

def get_dt(num_time_steps):
    return 2.0 / num_time_steps

dt = get_dt(num_time_steps)


def exact_solution(alpha, t, L=2.0, Nx=128, Ny=128, c=1.0):
    nu = 1.0 / alpha
    x_vals = np.linspace(-L, L, Nx)
    y_vals = np.linspace(-L, L, Ny)
    X, Y = np.meshgrid(x_vals, y_vals)
    if t <= 0:
        return np.zeros_like(X)
    rx = X - c * t
    ry = Y
    r2 = rx**2 + ry**2
    denominator = 4.0 * nu * t
    amplitude = 1.0 / (4.0 * np.pi * nu * t)
    U = amplitude * np.exp(-r2 / denominator)
    return U


class AdvectionDiffussionDataset:
    def __init__(self,
                 X: np.ndarray = None,
                 X_tau: np.ndarray = None,
                 t_values: np.ndarray = None,
                 tau_values: np.ndarray = None,
                 alpha_values: np.ndarray = None):
        self.X = X
        self.X_tau = X_tau
        self.t_values = t_values
        self.tau_values = tau_values
        self.alpha_values = alpha_values

    def append(self, other):
        self.X = np.concatenate([self.X, other.X]) if self.X is not None else other.X
        self.X_tau = np.concatenate([self.X_tau, other.X_tau]) if self.X_tau is not None else other.X_tau
        self.t_values = np.concatenate([self.t_values, other.t_values]) if self.t_values is not None else other.t_values
        self.tau_values = np.concatenate([self.tau_values, other.tau_values]) if self.tau_values is not None else other.tau_values
        self.alpha_values = np.concatenate([self.alpha_values, other.alpha_values]) if self.alpha_values is not None else other.alpha_values


@dataclass
class IntervalSplit:
    interpolation: tuple
    extrapolation_left: tuple
    extrapolation_right: tuple


def prepare_adv_diff_dataset(alpha_range=(0.01, 10), tau_range=(150, 400), dt=dt, n_samples=500):
    X = []
    X_tau = []
    t_values = []
    tau_values = []
    alpha_values = []
    TRANGE = (0.01, 2.0)
    while len(X) < n_samples:
        # sample alpha uniformly
        alpha = np.random.uniform(*alpha_range)
        t = np.random.uniform(*TRANGE)
        x_t = exact_solution(alpha, t)
        tau = np.random.randint(*tau_range)
        x_tau = exact_solution(alpha, t + (tau * dt))

        X.append(x_t)
        X_tau.append(x_tau)
        t_values.append(t)
        tau_values.append(tau)
        alpha_values.append(alpha)

    X = np.array(X)
    X_tau = np.array(X_tau)
    t_values = np.array(t_values)
    tau_values = np.array(tau_values)
    alpha_values = np.array(alpha_values)
    dataset = AdvectionDiffussionDataset(X, X_tau, t_values, tau_values, alpha_values)
    return dataset


def train_test_split_range(interval, interpolation_span=0.1, extrapolation_left_span=0.1, extrapolation_right_span=0.1):
    """
    Split the range into train and test ranges.
    We have three test folds:
    1. Interpolation fold: alpha and tau values are within the training (min, max) range but not in the training set.
       We sample an interval of length interpolation_span (as a fraction of the total range) at a random location.
    2. Extrapolation fold: alpha and tau values are outside the training (min, max) range.
       We sample two intervals of lengths extrapolation_left_span and extrapolation_right_span from the total range.
    3. Validation fold: alpha and tau values are randomly sampled from the total set.

    The overall interval looks like:
    Extrapolation_left_test | normal | Interpolation_test | normal | Extrapolation_right_test
    (min, extrapolation_left) | (extrapolation_left, interpolation_min) | (interpolation_min, interpolation_max) | (interpolation_max, extrapolation_right) | (extrapolation_right, max)
    and
    train, val = split(normal, val_split)
    """
    r_min, r_max = interval
    length = (r_max - r_min)
    extra_left_length = extrapolation_left_span * length
    extra_right_length = extrapolation_right_span * length
    inter_length = interpolation_span * length

    extrapolation_left = (r_min, r_min + extra_left_length)
    extrapolation_right = (r_max - extra_right_length, r_max)

    interpolation_min = np.random.uniform(extrapolation_left[1], extrapolation_right[0] - inter_length)
    interpolation = (interpolation_min, interpolation_min + inter_length)

    train_ranges = [(extrapolation_left[1], interpolation[0]), (interpolation[1], extrapolation_right[0])]
    return IntervalSplit(interpolation, extrapolation_left, extrapolation_right), train_ranges


def get_train_ranges(interval_split):
    return [
        (interval_split.extrapolation_left[1], interval_split.interpolation[0]),
        (interval_split.interpolation[1], interval_split.extrapolation_right[0])
    ]


def get_train_val_test_folds(alpha_range, tau_range,
                             alpha_interpolation_span=0.10,
                             alpha_extrapolation_left_span=0.10,
                             alpha_extrapolation_right_span=0.10,
                             tau_interpolation_span=0.10,
                             tau_extrapolation_left_span=0.10,
                             tau_extrapolation_right_span=0.10,
                             n_samples_train=500,
                             n_samples_val=200):
    """
    Generate train (4 sub-regions) and val (left extrp, interp, right extrp
    for alpha x left extrp, interp, right extrp for tau) datasets.

    Returns:
        dataset_train : AdvectionDiffussionDataset
        dataset_val : AdvectionDiffussionDataset
        alpha_interval_split: IntervalSplit
        tau_interval_split : IntervalSplit
    """
    # ---------------------------------------------------------------------
    # 1) Split alpha into 4 regions: left extrp, interp, right extrp, train
    # 2) Split tau into 4 regions: left extrp, interp, right extrp, train
    # ---------------------------------------------------------------------
    alpha_interval_split, alpha_train_ranges = train_test_split_range(
        alpha_range,
        alpha_interpolation_span,
        alpha_extrapolation_left_span,
        alpha_extrapolation_right_span
    )
    tau_interval_split, tau_train_ranges = train_test_split_range(
        tau_range,
        tau_interpolation_span,
        tau_extrapolation_left_span,
        tau_extrapolation_right_span
    )

    # alpha_train_ranges and tau_train_ranges each have 2 intervals:
    #   alpha_train_ranges = [ (a1_lo, a1_hi), (a2_lo, a2_hi) ]
    #   tau_train_ranges   = [ (t1_lo, t1_hi), (t2_lo, t2_hi) ]
    #
    # Meanwhile, alpha_interval_split has:
    #   alpha_interval_split.extrapolation_left  = (a_left_lo, a_left_hi)
    #   alpha_interval_split.interpolation       = (a_int_lo, a_int_hi)
    #   alpha_interval_split.extrapolation_right = (a_right_lo, a_right_hi)
    # and similarly for tau_interval_split.

    # -------------------------------------------------------------
    # 3) Build the TRAIN dataset from the Cartesian product
    #    of alpha_train_ranges x tau_train_ranges => 4 combos
    # -------------------------------------------------------------
    dataset_train = AdvectionDiffussionDataset()
    for alpha_subrange in alpha_train_ranges:   # 2 intervals
        for tau_subrange in tau_train_ranges:   # 2 intervals
            subset = prepare_adv_diff_dataset(
                alpha_range=alpha_subrange,
                tau_range=tau_subrange,
                n_samples=n_samples_train
            )
            dataset_train.append(subset)

    # -------------------------------------------------------------
    # 4) Build the VAL dataset from the leftover intervals:
    #    alpha in { left extrp, interp, right extrp }
    #    x tau in { left extrp, interp, right extrp } => up to 9 combos
    # -------------------------------------------------------------
    alpha_val_intervals = [
        alpha_interval_split.extrapolation_left,
        alpha_interval_split.interpolation,
        alpha_interval_split.extrapolation_right
    ]
    tau_val_intervals = [
        tau_interval_split.extrapolation_left,
        tau_interval_split.interpolation,
        tau_interval_split.extrapolation_right
    ]

    dataset_val = AdvectionDiffussionDataset()

    for a_val_range in alpha_val_intervals:
        for t_val_range in tau_val_intervals:
            subset_val = prepare_adv_diff_dataset(
                alpha_range=a_val_range,
                tau_range=t_val_range,
                n_samples=n_samples_val
            )
            dataset_val.append(subset_val)

    return dataset_train, dataset_val, alpha_interval_split, tau_interval_split


def plot_sample(dataset, i):
    """
    Plot a sample pair from the dataset.
    """
    X = dataset.X
    X_tau = dataset.X_tau
    t_values = dataset.t_values
    tau_values = dataset.tau_values
    alpha_values = dataset.alpha_values

    print("Shape of X:", X.shape)

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    im1 = axs[0].imshow(X[i], extent=[0, 1, 0, 1], origin='lower', cmap='hot')
    axs[0].set_title(f'Initial State (t: {t_values[i]})')
    plt.colorbar(im1, ax=axs[0])

    im2 = axs[1].imshow(X_tau[i], extent=[0, 1, 0, 1], origin='lower', cmap='hot')
    axs[1].set_title(f'Shifted State (t + tau): {t_values[i] + tau_values[i] * dt}')
    plt.colorbar(im2, ax=axs[1])

    fig.suptitle(f'Tau: {tau_values[i]}, Alpha: {alpha_values[i]:.4f}')
    plt.show()


def load_from_path(path):
    dataset_train_path = os.path.join(path, 'dataset_train.pkl')
    dataset_val_path = os.path.join(path, 'dataset_val.pkl')
    alpha_interval_split_path = os.path.join(path, 'alpha_interval_split.json')
    tau_interval_split_path = os.path.join(path, 'tau_interval_split.json')

    with open(dataset_train_path, 'rb') as f:
        dataset_train = pickle.load(f)
    with open(dataset_val_path, 'rb') as f:
        dataset_val = pickle.load(f)
    with open(alpha_interval_split_path, 'r') as f:
        alpha_interval_split = json.load(f)
    alpha_interval_split = IntervalSplit(**alpha_interval_split)
    with open(tau_interval_split_path, 'r') as f:
        tau_interval_split = json.load(f)
    tau_interval_split = IntervalSplit(**tau_interval_split)

    return dataset_train, dataset_val, alpha_interval_split, tau_interval_split
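Usage sketch (not part of the commit): exercises the fold-generation entry point above with tiny sample counts so it runs quickly; the alpha and tau ranges mirror the defaults of prepare_adv_diff_dataset.

from data_adv_dif import get_train_val_test_folds, plot_sample

dataset_train, dataset_val, alpha_split, tau_split = get_train_val_test_folds(
    alpha_range=(0.01, 10), tau_range=(150, 400),
    n_samples_train=10, n_samples_val=5,
)
print(dataset_train.X.shape)   # (40, 128, 128): 4 train sub-regions x 10 samples
plot_sample(dataset_train, 0)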
data_burgers.py
ADDED
@@ -0,0 +1,243 @@
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torch
from dataclasses import dataclass, asdict
import json


# Rnum = 1000
num_time_steps = 500
# # x = np.linspace(0.0,1.0,num=128)
# dx = 1.0/np.shape(x)[0]
# TSTEPS = np.linspace(0.0,2.0,num=num_time_steps)
# dt = 2.0/np.shape(TSTEPS)[0]

def get_dt(num_time_steps):
    return 2.0 / num_time_steps

dt = get_dt(num_time_steps)


def exact_solution(Rnum, t):
    x = np.linspace(0.0, 1.0, num=128)
    t0 = np.exp(Rnum / 8.0)
    return (x / (t + 1)) / (1.0 + np.sqrt((t + 1) / t0) * np.exp(Rnum * (x * x) / (4.0 * t + 4)))


class ReDataset:
    def __init__(self,
                 X: np.ndarray = None,
                 X_tau: np.ndarray = None,
                 t_values: np.ndarray = None,
                 tau_values: np.ndarray = None,
                 Re_values: np.ndarray = None):
        self.X = X
        self.X_tau = X_tau
        self.t_values = t_values
        self.tau_values = tau_values
        self.Re_values = Re_values

    def append(self, other):
        self.X = np.concatenate([self.X, other.X]) if self.X is not None else other.X
        self.X_tau = np.concatenate([self.X_tau, other.X_tau]) if self.X_tau is not None else other.X_tau
        self.t_values = np.concatenate([self.t_values, other.t_values]) if self.t_values is not None else other.t_values
        self.tau_values = np.concatenate([self.tau_values, other.tau_values]) if self.tau_values is not None else other.tau_values
        self.Re_values = np.concatenate([self.Re_values, other.Re_values]) if self.Re_values is not None else other.Re_values


@dataclass
class IntervalSplit:
    interpolation: tuple
    extrapolation_left: tuple
    extrapolation_right: tuple


def get_time_shifts(snapshots, tau_range=(100, 500), n_samples=100):
    X = []
    X_tau = []
    tau_values = []
    while len(X) < n_samples:
        tau = np.random.randint(*tau_range)
        i = np.random.randint(0, len(snapshots) - tau)
        X.append(snapshots[i])
        X_tau.append(snapshots[i + tau])
        tau_values.append(tau)
    X = np.array(X)
    X_tau = np.array(X_tau)
    tau_values = np.array(tau_values)
    return X, X_tau, tau_values


def prepare_Re_dataset(Re_range=(100, 2000), tau_range=(500, 1900), dt=dt, n_samples=5000):
    X = []
    X_tau = []
    t_values = []
    tau_values = []
    Re_values = []
    TRANGE = (0, 2)
    while len(X) < n_samples:
        # sample Re log-uniformly
        logRe = np.random.uniform(np.log(Re_range[0]), np.log(Re_range[1]))
        Re = np.exp(logRe).round().astype(int)
        t = np.random.uniform(*TRANGE)
        x_t = exact_solution(Re, t)
        # print('tau_range', tau_range)
        tau = np.random.randint(*tau_range)
        x_tau = exact_solution(Re, t + (tau * dt))

        X.append(x_t)
        X_tau.append(x_tau)
        t_values.append(t)
        tau_values.append(tau)
        Re_values.append(Re)

    X = np.array(X)
    X_tau = np.array(X_tau)
    t_values = np.array(t_values)
    tau_values = np.array(tau_values)
    Re_values = np.array(Re_values)
    # return X, X_tau, tau_values, Re_values
    dataset = ReDataset(X, X_tau, t_values, tau_values, Re_values)
    return dataset


def train_test_split_range(interval, interpolation_span=0.1, extrapolation_left_span=0.1, extrapolation_right_span=0.1):
    """
    Split the range into train and test ranges.
    We have three test folds:
    1. Interpolation fold: Re and tau values are within the training (min, max) range but not in the training set.
       We sample an interval of length interpolation_span (as a fraction of the total range) at a random location.
    2. Extrapolation fold: Re and tau values are outside the training (min, max) range.
       We sample two intervals of lengths extrapolation_left_span and extrapolation_right_span from the total range.
    3. Validation fold: Re and tau values are randomly sampled from the total set.

    The overall interval looks like:
    Extrapolation_left_test | normal | Interpolation_test | normal | Extrapolation_right_test
    (min, extrapolation_left) | (extrapolation_left, interpolation_min) | (interpolation_min, interpolation_max) | (interpolation_max, extrapolation_right) | (extrapolation_right, max)
    and
    train, val = split(normal, val_split)
    """
    r_min, r_max = interval
    length = (r_max - r_min)
    extra_left_length = extrapolation_left_span * length
    extra_right_length = extrapolation_right_span * length
    inter_length = interpolation_span * length

    extrapolation_left = (r_min, r_min + extra_left_length)
    extrapolation_right = (r_max - extra_right_length, r_max)

    interpolation_min = np.random.uniform(extrapolation_left[1], extrapolation_right[0] - inter_length)
    interpolation = (interpolation_min, interpolation_min + inter_length)

    train_ranges = [(extrapolation_left[1], interpolation[0]), (interpolation[1], extrapolation_right[0])]
    return IntervalSplit(interpolation, extrapolation_left, extrapolation_right), train_ranges


def get_train_ranges(interval_split):
    return [
        (interval_split.extrapolation_left[1], interval_split.interpolation[0]),
        (interval_split.interpolation[1], interval_split.extrapolation_right[0])
    ]


# def get_dataset_from_ranges(train_ranges):
#     dataset = ReDataset()
#     for re_train_range, tau_train_range in zip(Re_train_ranges, tau_train_ranges):
#         train_dataset = prepare_Re_dataset(Re_range=re_train_range, tau_range=tau_train_range, n_samples=n_samples_train)
#         dataset.append(train_dataset)
#     return dataset


def get_train_val_test_folds(Re_range, tau_range,
                             re_interpolation_span=0.10,
                             re_extrapolation_left_span=0.1,
                             re_extrapolation_right_span=0.10,
                             tau_interpolation_span=0.10,
                             tau_extrapolation_left_span=0.1,
                             tau_extrapolation_right_span=0.10,
                             n_samples_train=1000000,
                             val_split=0.2):

    Re_interval_split, Re_train_ranges = train_test_split_range(Re_range, re_interpolation_span, re_extrapolation_left_span, re_extrapolation_right_span)
    tau_interval_split, tau_train_ranges = train_test_split_range(tau_range, tau_interpolation_span, tau_extrapolation_left_span, tau_extrapolation_right_span)

    # print(Re_interval_split, Re_train_ranges)
    # print(tau_interval_split, tau_train_ranges)
    # prepare train dataset
    dataset = ReDataset()
    for re_train_range, tau_train_range in zip(Re_train_ranges, tau_train_ranges):
        train_dataset = prepare_Re_dataset(Re_range=re_train_range, tau_range=tau_train_range, n_samples=n_samples_train)
        dataset.append(train_dataset)

    inds = np.arange(len(dataset.X))
    np.random.shuffle(inds)
    train_inds = inds[:int(len(inds) * (1 - val_split))]
    val_inds = inds[int(len(inds) * (1 - val_split)):]
    dataset_train = ReDataset(dataset.X[train_inds], dataset.X_tau[train_inds], dataset.t_values[train_inds], dataset.tau_values[train_inds], dataset.Re_values[train_inds])
    dataset_val = ReDataset(dataset.X[val_inds], dataset.X_tau[val_inds], dataset.t_values[val_inds], dataset.tau_values[val_inds], dataset.Re_values[val_inds])
    return dataset_train, dataset_val, Re_interval_split, tau_interval_split


def plot_sample(dataset, i):
    X = dataset.X
    X_tau = dataset.X_tau
    Tau = dataset.tau_values
    Re_total = dataset.Re_values
    plt.plot(X[i], label="Initial State")
    plt.plot(X_tau[i], label="Mapped State")
    plt.title(f'Tau: {Tau[i]}, Re: {Re_total[i]}')
    plt.legend()
    plt.show()


def save_to_path(path, dataset_train, dataset_val, Re_interval_split, tau_interval_split):
    if not os.path.exists(path):
        os.makedirs(path)
    # save dataset_train, dataset_val, Re_interval_split, tau_interval_split to pkl/json files
    dataset_train_path = os.path.join(path, 'dataset_train.pkl')
    dataset_val_path = os.path.join(path, 'dataset_val.pkl')
    Re_interval_split_path = os.path.join(path, 'Re_interval_split.json')
    tau_interval_split_path = os.path.join(path, 'tau_interval_split.json')

    with open(dataset_train_path, 'wb') as f:
        pickle.dump(dataset_train, f)
    with open(dataset_val_path, 'wb') as f:
        pickle.dump(dataset_val, f)

    with open(Re_interval_split_path, 'w') as f:
        json.dump(asdict(Re_interval_split), f)
    with open(tau_interval_split_path, 'w') as f:
        json.dump(asdict(tau_interval_split), f)


def load_from_path(path):
    dataset_train_path = os.path.join(path, 'dataset_train.pkl')
    dataset_val_path = os.path.join(path, 'dataset_val.pkl')
    Re_interval_split_path = os.path.join(path, 'Re_interval_split.json')
    tau_interval_split_path = os.path.join(path, 'tau_interval_split.json')

    with open(dataset_train_path, 'rb') as f:
        dataset_train = pickle.load(f)
    with open(dataset_val_path, 'rb') as f:
        dataset_val = pickle.load(f)
    with open(Re_interval_split_path, 'r') as f:
        Re_interval_split = json.load(f)
    Re_interval_split = IntervalSplit(**Re_interval_split)
    with open(tau_interval_split_path, 'r') as f:
        tau_interval_split = json.load(f)
    tau_interval_split = IntervalSplit(**tau_interval_split)

    return dataset_train, dataset_val, Re_interval_split, tau_interval_split


def main():
    # Re_range = (100, 3000)
    # num_time_steps = 500
    # tau_range = (175, 425)
    # dataset_train, dataset_val, Re_interval_split, tau_interval_split = get_train_val_test_folds(Re_range, tau_range)
    # save_to_path('data', dataset_train, dataset_val, Re_interval_split, tau_interval_split)

    load_from_path('data')


if __name__ == '__main__':
    main()
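Usage sketch (not part of the commit): builds a snapshot trajectory from the analytical Burgers solution and draws time-shifted pairs with get_time_shifts. Re=1000 and the sample counts are illustrative choices.

import numpy as np
from data_burgers import exact_solution, get_time_shifts, dt

Rnum = 1000
snapshots = np.array([exact_solution(Rnum, n * dt) for n in range(500)])   # (500, 128)

X, X_tau, tau_values = get_time_shifts(snapshots, tau_range=(100, 400), n_samples=8)
print(X.shape, X_tau.shape, tau_values)   # (8, 128) (8, 128) plus the sampled offsets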
model_adv_dif.py
ADDED
@@ -0,0 +1,190 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os
import math
import pickle
import json
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


# Normalization layer for Conv2d features
class Norm(nn.Module):
    def __init__(self, num_channels, num_groups=4):
        super(Norm, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.norm(x)


# Encoder using Conv2d
class Encoder(nn.Module):
    def __init__(self, latent_dim=3):
        super(Encoder, self).__init__()
        self.conv_layers = nn.Sequential(
            # Input: (batch_size, 1, 128, 128)
            nn.Conv2d(1, 32, kernel_size=2, stride=2, padding=0),     # (batch_size, 32, 64, 64)
            nn.GELU(),
            Norm(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),    # (batch_size, 64, 32, 32)
            nn.GELU(),
            Norm(64),
            nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0),   # (batch_size, 128, 16, 16)
            nn.GELU(),
            Norm(128),
            nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=0),  # (batch_size, 256, 8, 8)
            nn.GELU(),
            Norm(256),
            nn.Conv2d(256, 512, kernel_size=2, stride=2, padding=0),  # (batch_size, 512, 4, 4)
            nn.GELU(),
            Norm(512),
        )
        self.flatten = nn.Flatten()
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var


class Decoder(nn.Module):
    def __init__(self, latent_dim=3):
        super(Decoder, self).__init__()
        # Fully connected layer to transform the latent vector back to the shape (batch_size, 512, 4, 4)
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        self.deconv_layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=1),
            nn.GELU(),
            Norm(256),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=1),
            nn.GELU(),
            Norm(128),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=1),
            nn.GELU(),
            Norm(64),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=1),
            nn.GELU(),
            Norm(32),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, z):
        # Transform the latent vector to match the shape of the feature maps
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)  # Reshape to (batch_size, 512, 4, 4)
        x = self.deconv_layers(x)
        return x


class Propagator_concat(nn.Module):
    """
    Takes in (z(t), tau, alpha) and outputs z(t+tau).
    """
    def __init__(self, latent_dim, feats=[16, 32, 64, 32, 16]):
        """
        Initialize the propagator network.
        Input : (z(t), tau, alpha)
        Output: z(t+tau)
        """
        super(Propagator_concat, self).__init__()

        self._net = nn.Sequential(
            nn.Linear(latent_dim + 2, feats[0]),  # +2 for tau and alpha; more params would increase this
            nn.GELU(),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            nn.Linear(feats[4], latent_dim),
        )

    def forward(self, z, tau, alpha):
        """
        Forward pass of the propagator.
        Concatenates the latent vector z with tau and alpha, then processes the result through the network.
        """
        zproj = z.squeeze(1)  # Adjust z dimensions if necessary
        z_ = torch.cat((zproj, tau, alpha), dim=1)  # Concatenate z, tau, and alpha along the last dimension
        z_tau = self._net(z_)
        return z_tau, z_


class Model(nn.Module):
    def __init__(self, encoder, decoder, propagator):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder        # decoder for x(t)
        self.propagator = propagator  # used to time march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, tau, alpha):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Small FCNN to get z(t+tau) from z(t)
        z_tau, z_ = self.propagator(z, tau, alpha)

        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau, z_


def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    """
    Compute the VAE loss components.
    :param x: Original input
    :param x_tau: Future input (ground truth)
    :param x_hat: Reconstructed x(t)
    :param x_hat_tau: Predicted x(t+tau)
    :param mean: Mean of the latent distribution
    :param log_var: Log variance of the latent distribution
    :return: reconstruction_loss1, reconstruction_loss2, KLD
    """
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)          # reconstruction loss for x(t)
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)  # prediction loss for x(t+tau)

    # Kullback-Leibler divergence
    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1))

    return reconstruction_loss1, reconstruction_loss2, KLD
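Wiring sketch (not part of the commit): one forward/loss pass of the 2D VAE on random 128x128 fields, the spatial size implied by the 512*4*4 bottleneck. The 1e-3 weight on the KLD term is taken from beta in config_adv_dif and is an assumption here.

import torch
from model_adv_dif import Encoder, Decoder, Propagator_concat, Model, loss_function

model = Model(Encoder(latent_dim=3), Decoder(latent_dim=3), Propagator_concat(latent_dim=3))

x     = torch.rand(4, 1, 128, 128)   # x(t)
x_tau = torch.rand(4, 1, 128, 128)   # x(t+tau), ground truth
tau   = torch.rand(4, 1)             # shift parameter, one scalar per sample
alpha = torch.rand(4, 1)             # diffusion parameter, one scalar per sample

x_hat, x_hat_tau, mean, log_var, z_tau, z_ = model(x, tau, alpha)
rec1, rec2, kld = loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var)
loss = rec1 + rec2 + 1e-3 * kld      # KLD weight assumed from config_adv_dif's beta
print(loss.item())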
model_io_adv_dif.py
ADDED
@@ -0,0 +1,22 @@
import torch
from dataclasses import dataclass, asdict
from data_adv_dif import IntervalSplit
from config_adv_dif import Config


def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    torch.save({
        'model_state_dict': model.state_dict(),
        'alpha_interval_split': asdict(alpha_interval_split),
        'tau_interval_split': asdict(tau_interval_split),
        'config': asdict(config),
    }, path)


def load_model(path, model):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config
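Checkpoint round-trip sketch (not part of the commit): saves the 2D model together with its interval splits and config, then restores it into the same model instance. The path checkpoints/adv_dif.pt and the throwaway fold generation (used here only to obtain interval splits) are assumptions.

import os
from config_adv_dif import Config
from data_adv_dif import get_train_val_test_folds
from model_adv_dif import Encoder, Decoder, Propagator_concat, Model
from model_io_adv_dif import save_model, load_model

config = Config()
model = Model(Encoder(config.latent_dim), Decoder(config.latent_dim),
              Propagator_concat(config.latent_dim))
_, _, alpha_split, tau_split = get_train_val_test_folds(
    (0.01, 10), (150, 400), n_samples_train=1, n_samples_val=1)

os.makedirs(config.save_dir, exist_ok=True)
ckpt = os.path.join(config.save_dir, 'adv_dif.pt')
save_model(ckpt, model, tau_split, alpha_split, config)
model, alpha_split, tau_split, config = load_model(ckpt, model)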
model_io_burgers.py
ADDED
@@ -0,0 +1,22 @@
import torch
from dataclasses import dataclass, asdict
from data_burgers import IntervalSplit
from config_burgers import Config


def save_model(path, model, tau_interval_split, re_interval_split, config):
    torch.save({
        'model_state_dict': model.state_dict(),
        're_interval_split': asdict(re_interval_split),
        'tau_interval_split': asdict(tau_interval_split),
        'config': asdict(config),
    }, path)


def load_model(path, model):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, re_interval_split, tau_interval_split, config
model_v2.py
ADDED
@@ -0,0 +1,380 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:

import numpy as np
import torch
import torch.nn as nn

import math


# In[2]:

def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                          -(math.log(10000.0) / d_model)))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)

    return pe


# In[3]:

class Norm(nn.Module):
    def __init__(self, num_channels, num_groups=4):
        super(Norm, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)


# In[4]:

class Norm_new(nn.Module):
    def __init__(self, num_channels, num_groups=4):
        super(Norm_new, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        if x.dim() == 2:
            # Reshape to (batch_size, num_channels, 1)
            x = x.unsqueeze(-1)
            x = self.norm(x)
            # Reshape back to (batch_size, num_channels)
            x = x.squeeze(-1)
        else:
            x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x


class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim=2, feats=[512, 256, 128, 64, 32]):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self._net = nn.Sequential(
            nn.Linear(input_dim, feats[0]),
            nn.GELU(),
            Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            Norm_new(feats[4]),
            nn.Linear(feats[4], 2 * latent_dim)
        )

    def forward(self, x):
        Z = self._net(x)
        mean, log_var = torch.split(Z, self.latent_dim, dim=-1)
        return mean, log_var


# In[5]:

class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, feats=[32, 64, 128, 256, 512]):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self._net = nn.Sequential(
            nn.Linear(latent_dim, feats[0]),
            nn.GELU(),
            Norm_new(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.GELU(),
            Norm_new(feats[1]),
            nn.Linear(feats[1], feats[2]),
            nn.GELU(),
            Norm_new(feats[2]),
            nn.Linear(feats[2], feats[3]),
            nn.GELU(),
            Norm_new(feats[3]),
            nn.Linear(feats[3], feats[4]),
            nn.GELU(),
            Norm_new(feats[4]),
            nn.Linear(feats[4], output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        y = self._net(x)
        return y


# In[6]:

class Propagator(nn.Module):
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64):
        """
        Takes in (z(t), tau) and outputs z(t+tau).
        Input : (z(t), tau)
        Output: z(t+tau)
        """
        super(Propagator, self).__init__()
        self.max_tau = max_tau
        self.register_buffer('encodings', positionalencoding1d(encoding_dim, max_tau))  # shape: (max_tau, encoding_dim)

        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )

        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau):
        zproj = self.projector(z)
        enc = self.encodings[tau.long()]
        # zproj and enc are both encoding_dim wide; they are combined by summation
        z = zproj + enc

        z_tau = self._net(z)
        return z_tau


# Same idea, with an additional positional embedding for Re
class Propagator_encoding(nn.Module):
    def __init__(self, latent_dim, feats=[16, 32], max_tau=10000, encoding_dim=64, max_re=5000):
        """
        Takes in (z(t), tau, re) and outputs z(t+tau).
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_encoding, self).__init__()
        self.max_tau = max_tau
        self.max_re = max_re
        self.register_buffer('tau_encodings', positionalencoding1d(encoding_dim, max_tau))  # shape: (max_tau, encoding_dim)
        self.register_buffer('re_encodings', positionalencoding1d(encoding_dim, max_re))    # shape: (max_re, encoding_dim)

        self.projector = nn.Sequential(
            nn.Linear(latent_dim, encoding_dim),
            nn.ReLU(),
            Norm(encoding_dim),
            nn.Linear(encoding_dim, encoding_dim),
        )

        self._net = nn.Sequential(
            nn.Linear(encoding_dim, feats[0]),
            nn.ReLU(),
            Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        zproj = self.projector(z)
        tau_enc = self.tau_encodings[tau.long()]
        re_enc = self.re_encodings[re.long()]
        # zproj and both encodings are encoding_dim wide; combine by summation
        z = zproj + tau_enc + re_enc
        z_tau = self._net(z)
        return z_tau


class Propagator_concat(nn.Module):
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Takes in (z(t), tau, re) and outputs z(t+tau).
        Input : (z(t), tau, re)
        Output: z(t+tau)
        """
        super(Propagator_concat, self).__init__()

        self._net = nn.Sequential(
            nn.Linear(latent_dim + 2, feats[0]),
            nn.ReLU(),
            # Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.ReLU(),
            # Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, tau, re):
        zproj = z.squeeze(1)
        z_ = torch.cat((zproj, tau, re), dim=1)
        z_tau = self._net(z_)
        z_tau = z_tau[:, None, :]

        return z_tau


class Propagator_concat_one_step(nn.Module):
    def __init__(self, latent_dim, feats=[16, 32]):
        """
        Takes in (z(t), Re) and outputs the one-step update.
        Input : (z(t), re)
        Output: z(t+1*dt)
        """
        super(Propagator_concat_one_step, self).__init__()

        self._net = nn.Sequential(
            nn.Linear(latent_dim + 1, feats[0]),
            nn.ReLU(),
            # Norm(feats[0]),
            nn.Linear(feats[0], feats[1]),
            nn.Tanh(),
            # Norm(feats[1]),
            nn.Linear(feats[1], latent_dim),
        )

    def forward(self, z, re):
        # zproj = z.squeeze(1)
        zproj = z
        z_ = torch.cat((zproj, re), dim=1)
        z_tau = self._net(z_)
        # z_tau = z_tau[:, None, :]

        return z_tau


class Model(nn.Module):
    def __init__(self, encoder, decoder, propagator):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder        # decoder for x(t)
        self.propagator = propagator  # used to time march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, tau, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Small FCNN to get z(t+tau) from z(t)
        z_tau = self.propagator(z, tau, re)

        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau


class Model_One_Step(nn.Module):  # only takes in X and Re, not tau, since tau = 1
    def __init__(self, encoder, decoder, propagator):
        super(Model_One_Step, self).__init__()
        self.encoder = encoder
        self.decoder = decoder        # decoder for x(t)
        self.propagator = propagator  # used to time march z(t) to z(t+tau)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x, re):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Small FCNN to get z(t+1*dt) from z(t) -- pairs with Propagator_concat_one_step
        z_tau = self.propagator(z, re)

        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)
        x_hat_tau = self.decoder(z_tau)

        return x_hat, x_hat_tau, mean, log_var, z_tau


class Model_reproduce(nn.Module):
    def __init__(self, encoder, decoder):
        super(Model_reproduce, self).__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Reconstruction
        x_hat = self.decoder(z)  # reconstruction of x(t)

        return x_hat, mean, log_var


# Define loss function
def loss_function_reproduce(x, x_hat, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)

    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, KLD


# Define loss function
def loss_function(x, x_tau, x_hat, x_hat_tau, mean, log_var):
    reconstruction_loss1 = nn.MSELoss()(x, x_hat)
    reconstruction_loss2 = nn.MSELoss()(x_tau, x_hat_tau)

    KLD = torch.mean(-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=2))
    return reconstruction_loss1, reconstruction_loss2, KLD


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params
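Forward-pass sketch (not part of the commit): Model's forward calls a three-argument propagator, so Propagator_encoding is the natural pairing here; the 128-point grid matches input_dim in config_burgers. Batch contents are random placeholders.

import torch
from model_v2 import Encoder, Decoder, Propagator_encoding, Model

latent_dim = 2
model = Model(Encoder(input_dim=128, latent_dim=latent_dim),
              Decoder(latent_dim, output_dim=128),
              Propagator_encoding(latent_dim, max_tau=10000, max_re=5000))

x   = torch.rand(8, 1, 128)             # batch of 1D Burgers snapshots x(t)
tau = torch.randint(100, 500, (8, 1))   # integer offsets indexing the tau encodings
re  = torch.randint(900, 2900, (8, 1))  # Reynolds numbers indexing the Re encodings

x_hat, x_hat_tau, mean, log_var, z_tau = model(x, tau, re)
print(x_hat.shape, x_hat_tau.shape)     # torch.Size([8, 1, 128]) twice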
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
numpy
gradio
matplotlib
dataclasses
json5