ahsanMah commited on
Commit
a202591
·
1 Parent(s): 2f9f504

usng config to build local model

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -16,10 +16,14 @@ from msma import ScoreFlow, build_model_from_pickle, config_presets
16
 
17
  @cache
18
  def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
 
 
 
19
  scorenet = build_model_from_pickle(preset=preset)
20
- model = ScoreFlow(scorenet, num_flows=8, device=device)
21
- model.flow.load_state_dict(torch.load(f"{modeldir}/{preset}/flow.pt"))
22
- return model
 
23
 
24
 
25
  @cache
@@ -113,8 +117,8 @@ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False
113
  if load_from_hub:
114
  model, modeldir = load_model_from_hub(preset=preset, device=device)
115
  else:
116
- modeldir = f"models/{preset}"
117
  model = load_model(modeldir="models", preset=preset, device=device)
 
118
  img_likelihood, score_norms = run_inference(model, img)
119
  nll, pct, ref_nll = compute_gmm_likelihood(
120
  score_norms, model_dir=modeldir
 
16
 
17
  @cache
18
  def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
19
+ modeldir = f"{modeldir}/{preset}"
20
+ with open(f"{modeldir}/config.json", "rb") as f:
21
+ model_params = json.load(f)
22
  scorenet = build_model_from_pickle(preset=preset)
23
+ model = ScoreFlow(scorenet, **model_params['PatchFlow'])
24
+ model.flow.load_state_dict(torch.load(f"{modeldir}/flow.pt"))
25
+ print("Loaded:", model_params)
26
+ return model.to(device)
27
 
28
 
29
  @cache
 
117
  if load_from_hub:
118
  model, modeldir = load_model_from_hub(preset=preset, device=device)
119
  else:
 
120
  model = load_model(modeldir="models", preset=preset, device=device)
121
+ modeldir = f"models/{preset}"
122
  img_likelihood, score_norms = run_inference(model, img)
123
  nll, pct, ref_nll = compute_gmm_likelihood(
124
  score_norms, model_dir=modeldir