new ckpt
- .gitignore +5 -0
- factories.py +10 -16
.gitignore CHANGED
@@ -1,6 +1,11 @@
 # Mac
 .DS_Store
 
+# PyCharm
 .idea/
+
+# Jupyter notebooks
 .ipynb_checkpoints/
+
+# Python
 __pycache__/
factories.py CHANGED
@@ -212,29 +212,24 @@ class EvalModel(torch.nn.Module):
     def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
         """Load the model we want to evaluate."""
         super().__init__()
-        self.
+        self.name = model_name
         self.ckpt_pth = ckpt_pth
-        self.name
-
-
-        if self.base_name == "unext_emb_physics_config_C":
+        if self.name not in self.all_models:
+            raise ValueError(f"{self.name} is unavailable.")
+        if self.name == "unext_emb_physics_config_C":
             if self.ckpt_pth == "":
-                self.ckpt_pth = "ckpt/
-            self.model = get_model(model_name=self.
+                self.ckpt_pth = "ckpt/ram.pth.tar"
+            self.model = get_model(model_name=self.name,
                                    device='cpu',
                                    **DEFAULT_MODEL_PARAMS)
 
-            # load model checkpoint
-            state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
-
+            # load model checkpoint on cpu
+            state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
+
             self.model.load_state_dict(state_dict)
             self.model.to(device_str)
             self.model.eval()
 
-            # add epoch in the model name
-            epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch']
-            self.name = self.name + f"+{epoch}"
-
     def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
         return self.model(y, physics=physics)
 
@@ -250,9 +245,8 @@ class BaselineModel(torch.nn.Module):
 
     def __init__(self, model_name: str, device_str: str = "cpu") -> None:
         super().__init__()
-        self.
+        self.name = model_name
        self.ckpt_pth = ""
-        self.name = self.base_name
         if self.name not in self.all_baselines:
             raise ValueError(f"{self.name} is unavailable.")
         elif self.name == "DPIR":