msong97 commited on
Commit
c082fb3
·
1 Parent(s): 499f595
Files changed (2) hide show
  1. .gitignore +5 -0
  2. factories.py +10 -16
.gitignore CHANGED
@@ -1,6 +1,11 @@
1
  # Mac
2
  .DS_Store
3
 
 
4
  .idea/
 
 
5
  .ipynb_checkpoints/
 
 
6
  __pycache__/
 
1
  # Mac
2
  .DS_Store
3
 
4
+ # PyCharm
5
  .idea/
6
+
7
+ # Jupyter notebooks
8
  .ipynb_checkpoints/
9
+
10
+ # Python
11
  __pycache__/
factories.py CHANGED
@@ -212,29 +212,24 @@ class EvalModel(torch.nn.Module):
212
  def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
213
  """Load the model we want to evaluate."""
214
  super().__init__()
215
- self.base_name = model_name
216
  self.ckpt_pth = ckpt_pth
217
- self.name = self.base_name
218
- if self.base_name not in self.all_models:
219
- raise ValueError(f"{self.base_name} is unavailable.")
220
- if self.base_name == "unext_emb_physics_config_C":
221
  if self.ckpt_pth == "":
222
- self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
223
- self.model = get_model(model_name=self.base_name,
224
  device='cpu',
225
  **DEFAULT_MODEL_PARAMS)
226
 
227
- # load model checkpoint
228
- state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)[
229
- 'state_dict'] # load on cpu
230
  self.model.load_state_dict(state_dict)
231
  self.model.to(device_str)
232
  self.model.eval()
233
 
234
- # add epoch in the model name
235
- epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch']
236
- self.name = self.name + f"+{epoch}"
237
-
238
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
239
  return self.model(y, physics=physics)
240
 
@@ -250,9 +245,8 @@ class BaselineModel(torch.nn.Module):
250
 
251
  def __init__(self, model_name: str, device_str: str = "cpu") -> None:
252
  super().__init__()
253
- self.base_name = model_name
254
  self.ckpt_pth = ""
255
- self.name = self.base_name
256
  if self.name not in self.all_baselines:
257
  raise ValueError(f"{self.name} is unavailable.")
258
  elif self.name == "DPIR":
 
212
  def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
213
  """Load the model we want to evaluate."""
214
  super().__init__()
215
+ self.name = model_name
216
  self.ckpt_pth = ckpt_pth
217
+ if self.name not in self.all_models:
218
+ raise ValueError(f"{self.name} is unavailable.")
219
+ if self.name == "unext_emb_physics_config_C":
 
220
  if self.ckpt_pth == "":
221
+ self.ckpt_pth = "ckpt/ram.pth.tar"
222
+ self.model = get_model(model_name=self.name,
223
  device='cpu',
224
  **DEFAULT_MODEL_PARAMS)
225
 
226
+ # load model checkpoint on cpu
227
+ state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
228
+
229
  self.model.load_state_dict(state_dict)
230
  self.model.to(device_str)
231
  self.model.eval()
232
 
 
 
 
 
233
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
234
  return self.model(y, physics=physics)
235
 
 
245
 
246
  def __init__(self, model_name: str, device_str: str = "cpu") -> None:
247
  super().__init__()
248
+ self.name = model_name
249
  self.ckpt_pth = ""
 
250
  if self.name not in self.all_baselines:
251
  raise ValueError(f"{self.name} is unavailable.")
252
  elif self.name == "DPIR":