ahsanMah committed on
Commit
95e0f92
·
1 Parent(s): d71875f

minor fixes to train flow runner

Browse files
Files changed (1) hide show
  1. msma.py +44 -16
msma.py CHANGED
@@ -189,7 +189,9 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
189
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
190
  refimg, reflabel = dsobj[0]
191
  print(f"Loading dataset from {dataset_path}")
192
- print(f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}")
 
 
193
  dsloader = torch.utils.data.DataLoader(
194
  dsobj, batch_size=48, num_workers=4, prefetch_factor=2
195
  )
@@ -211,7 +213,7 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
211
  print(f"Computed score norms for {score_norms.shape[0]} samples")
212
 
213
 
214
- def train_flow(dataset_path, preset, device="cuda"):
215
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
216
  refimg, reflabel = dsobj[0]
217
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
@@ -252,6 +254,7 @@ def train_flow(dataset_path, preset, device="cuda"):
252
  device=device,
253
  )
254
 
 
255
  pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
256
  step = 0
257
 
@@ -280,8 +283,8 @@ def train_flow(dataset_path, preset, device="cuda"):
280
  f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
281
  )
282
  step += 1
283
-
284
- torch.save(model.flow.state_dict(), f"out/msma/{preset}/flow.pt")
285
 
286
 
287
  @torch.inference_mode
@@ -327,22 +330,43 @@ def test_flow_runner(preset, device="cpu", load_weights=None):
327
  @click.command()
328
 
329
  # Main options.
330
- @click.option('--run', help='Which function to run',
331
- type=click.Choice(['cache-scores', 'train-flow', 'train-gmm'], case_sensitive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  )
333
- @click.option('--outdir', help='Where to load/save the results', metavar='DIR', type=str, required=True)
334
- @click.option('--preset', help='Configuration preset', metavar='STR', type=str, default='edm2-img64-s-fid', show_default=True)
335
- @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, default=None)
336
  def cmdline(run, outdir, **opts):
337
  device = "cuda" if torch.cuda.is_available() else "cpu"
338
- preset = opts['preset']
339
- dataset_path = opts['data']
340
-
341
- if run in ['cache-scores', 'train-flow']:
342
- assert opts['data'] is not None, "Provide path to dataset"
343
 
 
 
 
344
  if run == "cache-scores":
345
- cache_score_norms(preset=preset, dataset_path=dataset_path, outdir=outdir, device=device)
 
 
346
 
347
  if run == "train-gmm":
348
  train_gmm(
@@ -350,8 +374,11 @@ def cmdline(run, outdir, **opts):
350
  outdir=f"{outdir}/{preset}",
351
  grid_search=True,
352
  )
 
 
 
 
353
 
354
- # test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
355
  # train_flow(imagenette_path, preset, device)
356
 
357
  # cache_score_norms(
@@ -368,5 +395,6 @@ def cmdline(run, outdir, **opts):
368
  # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
369
  # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
370
 
 
371
  if __name__ == "__main__":
372
  cmdline()
 
189
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
190
  refimg, reflabel = dsobj[0]
191
  print(f"Loading dataset from {dataset_path}")
192
+ print(
193
+ f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
194
+ )
195
  dsloader = torch.utils.data.DataLoader(
196
  dsobj, batch_size=48, num_workers=4, prefetch_factor=2
197
  )
 
213
  print(f"Computed score norms for {score_norms.shape[0]} samples")
214
 
215
 
216
+ def train_flow(dataset_path, preset, outdir, device="cuda"):
217
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
218
  refimg, reflabel = dsobj[0]
219
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
 
254
  device=device,
255
  )
256
 
257
+ os.makedirs(f"{outdir}/{preset}", exist_ok=True)
258
  pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
259
  step = 0
260
 
 
283
  f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
284
  )
285
  step += 1
286
+
287
+ torch.save(model.flow.state_dict(), f"{outdir}/{preset}/flow.pt")
288
 
289
 
290
  @torch.inference_mode
 
330
  @click.command()
331
 
332
  # Main options.
333
+ @click.option(
334
+ "--run",
335
+ help="Which function to run",
336
+ type=click.Choice(
337
+ ["cache-scores", "train-flow", "train-gmm"], case_sensitive=False
338
+ ),
339
+ )
340
+ @click.option(
341
+ "--outdir",
342
+ help="Where to load/save the results",
343
+ metavar="DIR",
344
+ type=str,
345
+ required=True,
346
+ )
347
+ @click.option(
348
+ "--preset",
349
+ help="Configuration preset",
350
+ metavar="STR",
351
+ type=str,
352
+ default="edm2-img64-s-fid",
353
+ show_default=True,
354
+ )
355
+ @click.option(
356
+ "--data", help="Path to the dataset", metavar="ZIP|DIR", type=str, default=None
357
  )
 
 
 
358
  def cmdline(run, outdir, **opts):
359
  device = "cuda" if torch.cuda.is_available() else "cpu"
360
+ preset = opts["preset"]
361
+ dataset_path = opts["data"]
 
 
 
362
 
363
+ if run in ["cache-scores", "train-flow"]:
364
+ assert opts["data"] is not None, "Provide path to dataset"
365
+
366
  if run == "cache-scores":
367
+ cache_score_norms(
368
+ preset=preset, dataset_path=dataset_path, outdir=outdir, device=device
369
+ )
370
 
371
  if run == "train-gmm":
372
  train_gmm(
 
374
  outdir=f"{outdir}/{preset}",
375
  grid_search=True,
376
  )
377
+
378
+ if run == "train-flow":
379
+ train_flow(dataset_path, outdir=outdir, preset=preset, device=device)
380
+ test_flow_runner(preset, device=device, load_weights=f"{outdir}/{preset}/flow.pt")
381
 
 
382
  # train_flow(imagenette_path, preset, device)
383
 
384
  # cache_score_norms(
 
395
  # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
396
  # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
397
 
398
+
399
  if __name__ == "__main__":
400
  cmdline()