ahsanMah commited on
Commit
f44174e
·
1 Parent(s): 69da633

dropping inference mode for now

Browse files
Files changed (2) hide show
  1. app.py +3 -1
  2. msma.py +2 -2
app.py CHANGED
@@ -42,9 +42,11 @@ def load_model_from_hub(preset, device):
42
  cache_dir="/tmp/",
43
  )
44
 
 
 
45
  model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
46
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
47
-
48
  return model
49
 
50
 
 
42
  cache_dir="/tmp/",
43
  )
44
 
45
+ print("HF SAVE DIR:", hf_checkpoint)
46
+
47
  model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
48
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
49
+ model = model.eval().requires_grad_(False)
50
  return model
51
 
52
 
msma.py CHANGED
@@ -81,7 +81,7 @@ class EDMScorer(torch.nn.Module):
81
 
82
  self.register_buffer("sigma_steps", t_steps.to(torch.float64))
83
 
84
- @torch.inference_mode()
85
  def forward(
86
  self,
87
  x,
@@ -378,7 +378,7 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
378
 
379
  with open(f"{experiment_dir}/config.json", "w") as f:
380
  json.dump(model.config, f, sort_keys=True, indent=4)
381
-
382
  # totaliters = int(epochs * train_len)
383
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
384
  step = 0
 
81
 
82
  self.register_buffer("sigma_steps", t_steps.to(torch.float64))
83
 
84
+ # @torch.inference_mode()
85
  def forward(
86
  self,
87
  x,
 
378
 
379
  with open(f"{experiment_dir}/config.json", "w") as f:
380
  json.dump(model.config, f, sort_keys=True, indent=4)
381
+
382
  # totaliters = int(epochs * train_len)
383
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
384
  step = 0