Spaces:
Running
on
T4
Running
on
T4
update
Browse files- app.py +1 -1
- factories.py +5 -20
- model_factory.py +3 -2
app.py
CHANGED
@@ -172,7 +172,7 @@ def get_dataset(dataset_name):
|
|
172 |
|
173 |
|
174 |
# global variables shared by all users
|
175 |
-
ram_model = EvalModel(
|
176 |
ram_model.eval()
|
177 |
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
|
178 |
|
|
|
172 |
|
173 |
|
174 |
# global variables shared by all users
|
175 |
+
ram_model = EvalModel(device_str=DEVICE_STR)
|
176 |
ram_model.eval()
|
177 |
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
|
178 |
|
factories.py
CHANGED
@@ -170,32 +170,17 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
170 |
|
171 |
class EvalModel(torch.nn.Module):
|
172 |
"""Eval model.
|
173 |
-
|
174 |
-
Is there a difference with BaselineModel ?
|
175 |
-
-> BaselineModel should be models that are already trained and will have fixed weights.
|
176 |
-
-> Eval model will change depending on differents checkpoints.
|
177 |
"""
|
178 |
-
all_models = ["unext_emb_physics_config_C"]
|
179 |
|
180 |
-
def __init__(self,
|
181 |
"""Load the model we want to evaluate."""
|
182 |
super().__init__()
|
183 |
-
self.name =
|
184 |
-
self.
|
185 |
-
|
186 |
-
|
187 |
-
if self.name == "unext_emb_physics_config_C":
|
188 |
-
if self.ckpt_pth == "":
|
189 |
-
self.ckpt_pth = "ckpt/ram.pth.tar"
|
190 |
-
self.model = get_model()
|
191 |
-
self.model.to(device_str)
|
192 |
-
self.model.eval()
|
193 |
|
194 |
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
195 |
-
physics.noise_model.sigma = torch.nn.Parameter(torch.tensor([1e-06]))
|
196 |
-
physics.noise_model.gain = torch.nn.Parameter(torch.tensor([1e-06]))
|
197 |
-
print('sigma = ', physics.noise_model.sigma)
|
198 |
-
print('gain = ', physics.noise_model.gain)
|
199 |
return self.model(y, physics=physics)
|
200 |
|
201 |
|
|
|
170 |
|
171 |
class EvalModel(torch.nn.Module):
|
172 |
"""Eval model.
|
|
|
|
|
|
|
|
|
173 |
"""
|
|
|
174 |
|
175 |
+
def __init__(self, device_str: str = "cpu") -> None:
|
176 |
"""Load the model we want to evaluate."""
|
177 |
super().__init__()
|
178 |
+
self.name = 'RAM'
|
179 |
+
self.model = get_model()
|
180 |
+
self.model.to(device_str)
|
181 |
+
self.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
184 |
return self.model(y, physics=physics)
|
185 |
|
186 |
|
model_factory.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import torch
|
2 |
from models.ram import RAM
|
3 |
|
|
|
|
|
4 |
def get_model():
|
5 |
"""
|
6 |
Load the model.
|
@@ -12,6 +14,5 @@ def get_model():
|
|
12 |
:return: model
|
13 |
"""
|
14 |
model = RAM()
|
15 |
-
|
16 |
-
model.load_state_dict(state_dict)
|
17 |
return model
|
|
|
1 |
import torch
|
2 |
from models.ram import RAM
|
3 |
|
4 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
5 |
+
|
6 |
def get_model():
|
7 |
"""
|
8 |
Load the model.
|
|
|
14 |
:return: model
|
15 |
"""
|
16 |
model = RAM()
|
17 |
+
model.load_state_dict(torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))
|
|
|
18 |
return model
|