Spaces:
Sleeping
Sleeping
Khalid Rafiq
commited on
Commit
·
1c9f376
1
Parent(s):
e94fa7b
Fix: Map Advection-Diffusion model checkpoint to CPU
Browse files
.ipynb_checkpoints/model_io_adv_dif-checkpoint.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from dataclasses import dataclass, asdict
|
3 |
+
from data_adv_dif import IntervalSplit
|
4 |
+
from config_adv_dif import Config
|
5 |
+
|
6 |
+
|
7 |
+
def save_model(path, model, tau_interval_split, alpha_interval_split, config):
|
8 |
+
torch.save({
|
9 |
+
'model_state_dict': model.state_dict(),
|
10 |
+
'alpha_interval_split': asdict(alpha_interval_split),
|
11 |
+
'tau_interval_split': asdict(tau_interval_split),
|
12 |
+
'config': asdict(config),
|
13 |
+
}, path)
|
14 |
+
|
15 |
+
|
16 |
+
def load_model(path, model):
|
17 |
+
checkpoint = torch.load(path, map_location='cpu')
|
18 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
+
alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
|
20 |
+
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|
21 |
+
config = Config(**checkpoint['config'])
|
22 |
+
return model, alpha_interval_split, tau_interval_split, config
|
.ipynb_checkpoints/model_io_burgers-checkpoint.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from dataclasses import dataclass, asdict
|
3 |
+
from data_burgers import IntervalSplit
|
4 |
+
from config_burgers import Config
|
5 |
+
|
6 |
+
|
7 |
+
def save_model(path, model, tau_interval_split, re_interval_split, config):
|
8 |
+
torch.save({
|
9 |
+
'model_state_dict': model.state_dict(),
|
10 |
+
're_interval_split': asdict(re_interval_split),
|
11 |
+
'tau_interval_split': asdict(tau_interval_split),
|
12 |
+
'config': asdict(config),
|
13 |
+
}, path)
|
14 |
+
|
15 |
+
|
16 |
+
def load_model(path, model):
|
17 |
+
checkpoint = torch.load(path, map_location='cpu')
|
18 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
+
re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
|
20 |
+
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|
21 |
+
config = Config(**checkpoint['config'])
|
22 |
+
return model, re_interval_split, tau_interval_split, config
|
__pycache__/model_io_adv_dif.cpython-310.pyc
CHANGED
Binary files a/__pycache__/model_io_adv_dif.cpython-310.pyc and b/__pycache__/model_io_adv_dif.cpython-310.pyc differ
|
|
__pycache__/model_io_burgers.cpython-310.pyc
CHANGED
Binary files a/__pycache__/model_io_burgers.cpython-310.pyc and b/__pycache__/model_io_burgers.cpython-310.pyc differ
|
|
app.ipynb
CHANGED
@@ -10,7 +10,7 @@
|
|
10 |
"name": "stdout",
|
11 |
"output_type": "stream",
|
12 |
"text": [
|
13 |
-
"* Running on local URL: http://127.0.0.1:
|
14 |
"\n",
|
15 |
"To create a public link, set `share=True` in `launch()`.\n"
|
16 |
]
|
@@ -18,7 +18,7 @@
|
|
18 |
{
|
19 |
"data": {
|
20 |
"text/html": [
|
21 |
-
"<div><iframe src=\"http://127.0.0.1:
|
22 |
],
|
23 |
"text/plain": [
|
24 |
"<IPython.core.display.HTML object>"
|
|
|
10 |
"name": "stdout",
|
11 |
"output_type": "stream",
|
12 |
"text": [
|
13 |
+
"* Running on local URL: http://127.0.0.1:7886\n",
|
14 |
"\n",
|
15 |
"To create a public link, set `share=True` in `launch()`.\n"
|
16 |
]
|
|
|
18 |
{
|
19 |
"data": {
|
20 |
"text/html": [
|
21 |
+
"<div><iframe src=\"http://127.0.0.1:7886/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
22 |
],
|
23 |
"text/plain": [
|
24 |
"<IPython.core.display.HTML object>"
|
model_io_adv_dif.py
CHANGED
@@ -14,7 +14,7 @@ def save_model(path, model, tau_interval_split, alpha_interval_split, config):
|
|
14 |
|
15 |
|
16 |
def load_model(path, model):
|
17 |
-
checkpoint = torch.load(path)
|
18 |
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
|
20 |
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|
|
|
14 |
|
15 |
|
16 |
def load_model(path, model):
|
17 |
+
checkpoint = torch.load(path, map_location='cpu')
|
18 |
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
|
20 |
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|
model_io_burgers.py
CHANGED
@@ -14,7 +14,7 @@ def save_model(path, model, tau_interval_split, re_interval_split, config):
|
|
14 |
|
15 |
|
16 |
def load_model(path, model):
|
17 |
-
checkpoint = torch.load(path)
|
18 |
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
|
20 |
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|
|
|
14 |
|
15 |
|
16 |
def load_model(path, model):
|
17 |
+
checkpoint = torch.load(path, map_location='cpu')
|
18 |
model.load_state_dict(checkpoint['model_state_dict'])
|
19 |
re_interval_split = IntervalSplit(**checkpoint['re_interval_split'])
|
20 |
tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
|