Spaces:
Running
on
Zero
Running
on
Zero
minor fixes to train flow runner
Browse files
msma.py
CHANGED
@@ -189,7 +189,9 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
|
|
189 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
190 |
refimg, reflabel = dsobj[0]
|
191 |
print(f"Loading dataset from {dataset_path}")
|
192 |
-
print(
|
|
|
|
|
193 |
dsloader = torch.utils.data.DataLoader(
|
194 |
dsobj, batch_size=48, num_workers=4, prefetch_factor=2
|
195 |
)
|
@@ -211,7 +213,7 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
|
|
211 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
212 |
|
213 |
|
214 |
-
def train_flow(dataset_path, preset, device="cuda"):
|
215 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
216 |
refimg, reflabel = dsobj[0]
|
217 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
@@ -252,6 +254,7 @@ def train_flow(dataset_path, preset, device="cuda"):
|
|
252 |
device=device,
|
253 |
)
|
254 |
|
|
|
255 |
pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
|
256 |
step = 0
|
257 |
|
@@ -280,8 +283,8 @@ def train_flow(dataset_path, preset, device="cuda"):
|
|
280 |
f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
|
281 |
)
|
282 |
step += 1
|
283 |
-
|
284 |
-
torch.save(model.flow.state_dict(), f"
|
285 |
|
286 |
|
287 |
@torch.inference_mode
|
@@ -327,22 +330,43 @@ def test_flow_runner(preset, device="cpu", load_weights=None):
|
|
327 |
@click.command()
|
328 |
|
329 |
# Main options.
|
330 |
-
@click.option(
|
331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
)
|
333 |
-
@click.option('--outdir', help='Where to load/save the results', metavar='DIR', type=str, required=True)
|
334 |
-
@click.option('--preset', help='Configuration preset', metavar='STR', type=str, default='edm2-img64-s-fid', show_default=True)
|
335 |
-
@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, default=None)
|
336 |
def cmdline(run, outdir, **opts):
|
337 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
338 |
-
preset = opts[
|
339 |
-
dataset_path = opts[
|
340 |
-
|
341 |
-
if run in ['cache-scores', 'train-flow']:
|
342 |
-
assert opts['data'] is not None, "Provide path to dataset"
|
343 |
|
|
|
|
|
|
|
344 |
if run == "cache-scores":
|
345 |
-
cache_score_norms(
|
|
|
|
|
346 |
|
347 |
if run == "train-gmm":
|
348 |
train_gmm(
|
@@ -350,8 +374,11 @@ def cmdline(run, outdir, **opts):
|
|
350 |
outdir=f"{outdir}/{preset}",
|
351 |
grid_search=True,
|
352 |
)
|
|
|
|
|
|
|
|
|
353 |
|
354 |
-
# test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
|
355 |
# train_flow(imagenette_path, preset, device)
|
356 |
|
357 |
# cache_score_norms(
|
@@ -368,5 +395,6 @@ def cmdline(run, outdir, **opts):
|
|
368 |
# nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
|
369 |
# print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|
370 |
|
|
|
371 |
if __name__ == "__main__":
|
372 |
cmdline()
|
|
|
189 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
190 |
refimg, reflabel = dsobj[0]
|
191 |
print(f"Loading dataset from {dataset_path}")
|
192 |
+
print(
|
193 |
+
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
194 |
+
)
|
195 |
dsloader = torch.utils.data.DataLoader(
|
196 |
dsobj, batch_size=48, num_workers=4, prefetch_factor=2
|
197 |
)
|
|
|
213 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
214 |
|
215 |
|
216 |
+
def train_flow(dataset_path, preset, outdir, device="cuda"):
|
217 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
218 |
refimg, reflabel = dsobj[0]
|
219 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
|
|
254 |
device=device,
|
255 |
)
|
256 |
|
257 |
+
os.makedirs(f"{outdir}/{preset}", exist_ok=True)
|
258 |
pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
|
259 |
step = 0
|
260 |
|
|
|
283 |
f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
|
284 |
)
|
285 |
step += 1
|
286 |
+
|
287 |
+
torch.save(model.flow.state_dict(), f"{outdir}/{preset}/flow.pt")
|
288 |
|
289 |
|
290 |
@torch.inference_mode
|
|
|
330 |
@click.command()
|
331 |
|
332 |
# Main options.
|
333 |
+
@click.option(
|
334 |
+
"--run",
|
335 |
+
help="Which function to run",
|
336 |
+
type=click.Choice(
|
337 |
+
["cache-scores", "train-flow", "train-gmm"], case_sensitive=False
|
338 |
+
),
|
339 |
+
)
|
340 |
+
@click.option(
|
341 |
+
"--outdir",
|
342 |
+
help="Where to load/save the results",
|
343 |
+
metavar="DIR",
|
344 |
+
type=str,
|
345 |
+
required=True,
|
346 |
+
)
|
347 |
+
@click.option(
|
348 |
+
"--preset",
|
349 |
+
help="Configuration preset",
|
350 |
+
metavar="STR",
|
351 |
+
type=str,
|
352 |
+
default="edm2-img64-s-fid",
|
353 |
+
show_default=True,
|
354 |
+
)
|
355 |
+
@click.option(
|
356 |
+
"--data", help="Path to the dataset", metavar="ZIP|DIR", type=str, default=None
|
357 |
)
|
|
|
|
|
|
|
358 |
def cmdline(run, outdir, **opts):
|
359 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
360 |
+
preset = opts["preset"]
|
361 |
+
dataset_path = opts["data"]
|
|
|
|
|
|
|
362 |
|
363 |
+
if run in ["cache-scores", "train-flow"]:
|
364 |
+
assert opts["data"] is not None, "Provide path to dataset"
|
365 |
+
|
366 |
if run == "cache-scores":
|
367 |
+
cache_score_norms(
|
368 |
+
preset=preset, dataset_path=dataset_path, outdir=outdir, device=device
|
369 |
+
)
|
370 |
|
371 |
if run == "train-gmm":
|
372 |
train_gmm(
|
|
|
374 |
outdir=f"{outdir}/{preset}",
|
375 |
grid_search=True,
|
376 |
)
|
377 |
+
|
378 |
+
if run == "train-flow":
|
379 |
+
train_flow(dataset_path, outdir=outdir, preset=preset, device=device)
|
380 |
+
test_flow_runner(preset, device=device, load_weights=f"{outdir}/{preset}/flow.pt")
|
381 |
|
|
|
382 |
# train_flow(imagenette_path, preset, device)
|
383 |
|
384 |
# cache_score_norms(
|
|
|
395 |
# nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
|
396 |
# print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|
397 |
|
398 |
+
|
399 |
if __name__ == "__main__":
|
400 |
cmdline()
|