Khalid Rafiq commited on
Commit
1c9f376
·
1 Parent(s): e94fa7b

Fix: Map Advection-Diffusion model checkpoint to CPU

Browse files
.ipynb_checkpoints/model_io_adv_dif-checkpoint.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dataclasses import dataclass, asdict
3
+ from data_adv_dif import IntervalSplit
4
+ from config_adv_dif import Config
5
+
6
+
7
+ def save_model(path, model, tau_interval_split, alpha_interval_split, config):
8
+ torch.save({
9
+ 'model_state_dict': model.state_dict(),
10
+ 'alpha_interval_split': asdict(alpha_interval_split),
11
+ 'tau_interval_split': asdict(tau_interval_split),
12
+ 'config': asdict(config),
13
+ }, path)
14
+
15
+
16
+ def load_model(path, model):
17
+ checkpoint = torch.load(path, map_location='cpu')
18
+ model.load_state_dict(checkpoint['model_state_dict'])
19
+ alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
20
+ tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
21
+ config = Config(**checkpoint['config'])
22
+ return model, alpha_interval_split, tau_interval_split, config
.ipynb_checkpoints/model_io_burgers-checkpoint.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dataclasses import dataclass, asdict
3
+ from data_burgers import IntervalSplit
4
+ from config_burgers import Config
5
+
6
+
7
+ def save_model(path, model, tau_interval_split, re_interval_split, config):
8
+ torch.save({
9
+ 'model_state_dict': model.state_dict(),
10
+ 're_interval_split': asdict(re_interval_split),
11
+ 'tau_interval_split': asdict(tau_interval_split),
12
+ 'config': asdict(config),
13
+ }, path)
14
+
15
+
16
+ def load_model(path, model):
17
+ checkpoint = torch.load(path, map_location='cpu')
18
+ model.load_state_dict(checkpoint['model_state_dict'])
19
+ re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
20
+ tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
21
+ config = Config(**checkpoint['config'])
22
+ return model, re_interval_split, tau_interval_split, config
__pycache__/model_io_adv_dif.cpython-310.pyc CHANGED
Binary files a/__pycache__/model_io_adv_dif.cpython-310.pyc and b/__pycache__/model_io_adv_dif.cpython-310.pyc differ
 
__pycache__/model_io_burgers.cpython-310.pyc CHANGED
Binary files a/__pycache__/model_io_burgers.cpython-310.pyc and b/__pycache__/model_io_burgers.cpython-310.pyc differ
 
app.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
- "* Running on local URL: http://127.0.0.1:7885\n",
14
  "\n",
15
  "To create a public link, set `share=True` in `launch()`.\n"
16
  ]
@@ -18,7 +18,7 @@
18
  {
19
  "data": {
20
  "text/html": [
21
- "<div><iframe src=\"http://127.0.0.1:7885/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
  ],
23
  "text/plain": [
24
  "<IPython.core.display.HTML object>"
 
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
+ "* Running on local URL: http://127.0.0.1:7886\n",
14
  "\n",
15
  "To create a public link, set `share=True` in `launch()`.\n"
16
  ]
 
18
  {
19
  "data": {
20
  "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7886/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
  ],
23
  "text/plain": [
24
  "<IPython.core.display.HTML object>"
model_io_adv_dif.py CHANGED
@@ -14,7 +14,7 @@ def save_model(path, model, tau_interval_split, alpha_interval_split, config):
14
 
15
 
16
  def load_model(path, model):
17
- checkpoint = torch.load(path)
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
20
  tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
 
14
 
15
 
16
  def load_model(path, model):
17
+ checkpoint = torch.load(path, map_location='cpu')
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
20
  tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
model_io_burgers.py CHANGED
@@ -14,7 +14,7 @@ def save_model(path, model, tau_interval_split, re_interval_split, config):
14
 
15
 
16
  def load_model(path, model):
17
- checkpoint = torch.load(path)
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
20
  tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
 
14
 
15
 
16
  def load_model(path, model):
17
+ checkpoint = torch.load(path, map_location='cpu')
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
20
  tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])