Spaces:
Sleeping
Sleeping
update
Browse files- factories.py +1 -32
factories.py
CHANGED
@@ -10,30 +10,6 @@ from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDatas
|
|
10 |
from model_factory import get_model
|
11 |
from physics.blur_generator import GaussianBlurGenerator
|
12 |
|
13 |
-
DEFAULT_MODEL_PARAMS = {
|
14 |
-
"in_channels": [1, 2, 3],
|
15 |
-
"grayscale": False,
|
16 |
-
"conv_type": "base",
|
17 |
-
"pool_type": "base",
|
18 |
-
"layer_scale_init_value": 1e-6,
|
19 |
-
"init_type": "ortho",
|
20 |
-
"gain_init_conv": 1.0,
|
21 |
-
"gain_init_linear": 1.0,
|
22 |
-
"drop_prob": 0.0,
|
23 |
-
"replk": False,
|
24 |
-
"mult_fact": 4,
|
25 |
-
"antialias": "gaussian",
|
26 |
-
"nc_base": 64,
|
27 |
-
"cond_type": "base",
|
28 |
-
"blind": False,
|
29 |
-
"pretrained_pth": None,
|
30 |
-
"N": 2,
|
31 |
-
"c_mult": 2,
|
32 |
-
"depth_encoding": 2,
|
33 |
-
"relu_in_encoding": False,
|
34 |
-
"skip_in_encoding": True
|
35 |
-
}
|
36 |
-
|
37 |
|
38 |
class PhysicsWithGenerator(torch.nn.Module):
|
39 |
"""Interface between Physics, Generator and Gradio."""
|
@@ -211,14 +187,7 @@ class EvalModel(torch.nn.Module):
|
|
211 |
if self.name == "unext_emb_physics_config_C":
|
212 |
if self.ckpt_pth == "":
|
213 |
self.ckpt_pth = "ckpt/ram.pth.tar"
|
214 |
-
self.model = get_model(
|
215 |
-
device='cpu',
|
216 |
-
**DEFAULT_MODEL_PARAMS)
|
217 |
-
|
218 |
-
# load model checkpoint on cpu
|
219 |
-
state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
|
220 |
-
|
221 |
-
self.model.load_state_dict(state_dict)
|
222 |
self.model.to(device_str)
|
223 |
self.model.eval()
|
224 |
|
|
|
10 |
from model_factory import get_model
|
11 |
from physics.blur_generator import GaussianBlurGenerator
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class PhysicsWithGenerator(torch.nn.Module):
|
15 |
"""Interface between Physics, Generator and Gradio."""
|
|
|
187 |
if self.name == "unext_emb_physics_config_C":
|
188 |
if self.ckpt_pth == "":
|
189 |
self.ckpt_pth = "ckpt/ram.pth.tar"
|
190 |
+
self.model = get_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
self.model.to(device_str)
|
192 |
self.model.eval()
|
193 |
|