mterris committed
Commit 81c09b8 · 1 Parent(s): fe1f918
Files changed (3):
  1. app.py +1 -1
  2. factories.py +5 -20
  3. model_factory.py +3 -2
app.py CHANGED
@@ -172,7 +172,7 @@ def get_dataset(dataset_name):
 
 
 # global variables shared by all users
-ram_model = EvalModel("unext_emb_physics_config_C", device_str=DEVICE_STR)
+ram_model = EvalModel(device_str=DEVICE_STR)
 ram_model.eval()
 psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
 
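For reference, here is roughly how the updated call site in app.py reads after this commit. The value of DEVICE_STR and the import path for EvalModel are assumptions; neither appears in the hunk above.

import torch

from factories import EvalModel  # assumption: EvalModel is defined in factories.py (see below)

DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: set earlier in app.py

# global variables shared by all users
ram_model = EvalModel(device_str=DEVICE_STR)  # the model name argument is no longer needed
ram_model.eval()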
factories.py CHANGED
@@ -170,32 +170,17 @@ class PhysicsWithGenerator(torch.nn.Module):
 
 class EvalModel(torch.nn.Module):
     """Eval model.
-
-    Is there a difference with BaselineModel ?
-    -> BaselineModel should be models that are already trained and will have fixed weights.
-    -> Eval model will change depending on differents checkpoints.
     """
-    all_models = ["unext_emb_physics_config_C"]
 
-    def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
+    def __init__(self, device_str: str = "cpu") -> None:
         """Load the model we want to evaluate."""
         super().__init__()
-        self.name = model_name
-        self.ckpt_pth = ckpt_pth
-        if self.name not in self.all_models:
-            raise ValueError(f"{self.name} is unavailable.")
-        if self.name == "unext_emb_physics_config_C":
-            if self.ckpt_pth == "":
-                self.ckpt_pth = "ckpt/ram.pth.tar"
-            self.model = get_model()
-            self.model.to(device_str)
-            self.model.eval()
+        self.name = 'RAM'
+        self.model = get_model()
+        self.model.to(device_str)
+        self.model.eval()
 
     def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
-        physics.noise_model.sigma = torch.nn.Parameter(torch.tensor([1e-06]))
-        physics.noise_model.gain = torch.nn.Parameter(torch.tensor([1e-06]))
-        print('sigma = ', physics.noise_model.sigma)
-        print('gain = ', physics.noise_model.gain)
         return self.model(y, physics=physics)
 
 
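As a minimal usage sketch of the slimmed-down EvalModel: checkpoint selection now happens inside get_model(), and forward() passes the physics operator through unchanged (the noise-model overrides and debug prints are gone). The placeholders y and physics stand for a measurement tensor and the operator that produced it; both depend on code not shown in this diff.

import torch

from factories import EvalModel  # assumption: import path for the class above

model = EvalModel(device_str="cpu")

# Placeholders (assumptions): `y` is a measurement tensor, `physics` is the
# torch.nn.Module forward operator that generated it elsewhere in factories.py.
with torch.no_grad():
    x_hat = model(y, physics)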
model_factory.py CHANGED
@@ -1,6 +1,8 @@
 import torch
 from models.ram import RAM
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 def get_model():
     """
     Load the model.
@@ -12,6 +14,5 @@ def get_model():
     :return: model
     """
     model = RAM()
-    state_dict = torch.load('ckpt/ram.pth.tar')
-    model.load_state_dict(state_dict)
+    model.load_state_dict(torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))
     return model
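Note that the new line calls hf_hub_download, which comes from the huggingface_hub package and is not imported in the hunk above. A sketch of the resulting module, assuming that import is added at the top of model_factory.py:

import torch
from huggingface_hub import hf_hub_download  # assumption: added alongside the imports shown above
from models.ram import RAM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_model():
    """Load the RAM model with weights fetched from the Hugging Face Hub."""
    model = RAM()
    # hf_hub_download caches the checkpoint locally and returns its path
    ckpt_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return model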