mterris commited on
Commit
fe1f918
·
1 Parent(s): ed95f9b
Files changed (1) hide show
  1. factories.py +1 -32
factories.py CHANGED
@@ -10,30 +10,6 @@ from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDatas
10
  from model_factory import get_model
11
  from physics.blur_generator import GaussianBlurGenerator
12
 
13
- DEFAULT_MODEL_PARAMS = {
14
- "in_channels": [1, 2, 3],
15
- "grayscale": False,
16
- "conv_type": "base",
17
- "pool_type": "base",
18
- "layer_scale_init_value": 1e-6,
19
- "init_type": "ortho",
20
- "gain_init_conv": 1.0,
21
- "gain_init_linear": 1.0,
22
- "drop_prob": 0.0,
23
- "replk": False,
24
- "mult_fact": 4,
25
- "antialias": "gaussian",
26
- "nc_base": 64,
27
- "cond_type": "base",
28
- "blind": False,
29
- "pretrained_pth": None,
30
- "N": 2,
31
- "c_mult": 2,
32
- "depth_encoding": 2,
33
- "relu_in_encoding": False,
34
- "skip_in_encoding": True
35
- }
36
-
37
 
38
  class PhysicsWithGenerator(torch.nn.Module):
39
  """Interface between Physics, Generator and Gradio."""
@@ -211,14 +187,7 @@ class EvalModel(torch.nn.Module):
211
  if self.name == "unext_emb_physics_config_C":
212
  if self.ckpt_pth == "":
213
  self.ckpt_pth = "ckpt/ram.pth.tar"
214
- self.model = get_model(model_name=self.name,
215
- device='cpu',
216
- **DEFAULT_MODEL_PARAMS)
217
-
218
- # load model checkpoint on cpu
219
- state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
220
-
221
- self.model.load_state_dict(state_dict)
222
  self.model.to(device_str)
223
  self.model.eval()
224
 
 
10
  from model_factory import get_model
11
  from physics.blur_generator import GaussianBlurGenerator
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  class PhysicsWithGenerator(torch.nn.Module):
15
  """Interface between Physics, Generator and Gradio."""
 
187
  if self.name == "unext_emb_physics_config_C":
188
  if self.ckpt_pth == "":
189
  self.ckpt_pth = "ckpt/ram.pth.tar"
190
+ self.model = get_model()
 
 
 
 
 
 
 
191
  self.model.to(device_str)
192
  self.model.eval()
193