ahsanMah committed on
Commit
bf573cf
·
1 Parent(s): b1602ac

+ added cmd line to msma

Browse files
Files changed (2) hide show
  1. app.py +33 -13
  2. msma.py +44 -18
app.py CHANGED
@@ -6,12 +6,14 @@ import matplotlib.pyplot as plt
6
  import numpy as np
7
  import torch
8
 
9
- from msma import build_model, config_presets
10
 
11
 
12
  @cache
13
- def load_model(preset="edm2-img64-s-fid", device='cpu'):
14
- return build_model(preset, device)
 
 
15
 
16
  @cache
17
  def load_reference_scores(model_dir):
@@ -38,24 +40,42 @@ def plot_against_reference(nll, ref_nll):
38
  return fig
39
 
40
 
41
- def run_inference(img, preset="edm2-img64-s-fid", device="cuda"):
42
- img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
43
- img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
44
- model = load_model(preset=preset, device=device)
45
- x = model(img.cuda())
46
- x = x.square().sum(dim=(2, 3, 4)) ** 0.5
47
- nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")
48
 
49
- plot = plot_against_reference(nll, ref_nll)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
52
- return outstr, plot
 
 
 
53
 
54
 
55
  demo = gr.Interface(
56
  fn=run_inference,
57
  inputs=["image"],
58
- outputs=["text", gr.Plot(label="Comparing to Imagenette")],
 
 
 
59
  )
60
 
61
  if __name__ == "__main__":
 
6
  import numpy as np
7
  import torch
8
 
9
+ from msma import ScoreFlow, config_presets
10
 
11
 
12
  @cache
13
def load_model(modeldir, preset="edm2-img64-s-fid", device='cpu', outdir=None):
    """Build a ScoreFlow for ``preset`` and restore its trained flow weights.

    Args:
        modeldir: Root directory containing one sub-directory per preset; the
            weights are read from ``{modeldir}/{preset}/flow.pt``.
        preset: Configuration preset name forwarded to ``ScoreFlow``.
        device: Device forwarded to ``ScoreFlow`` and used to remap the
            checkpoint tensors while loading.
        outdir: Unused by this function; kept so existing callers keep working.

    Returns:
        The ``ScoreFlow`` instance with ``model.flow`` loaded from disk.
    """
    model = ScoreFlow(preset, device=device)
    # Remap storages onto the requested device: a checkpoint written from a
    # CUDA run cannot be deserialized on a CPU-only host without map_location.
    state_dict = torch.load(f"{modeldir}/{preset}/flow.pt", map_location=device)
    model.flow.load_state_dict(state_dict)
    return model
17
 
18
  @cache
19
  def load_reference_scores(model_dir):
 
40
  return fig
41
 
42
 
43
def plot_heatmap(heatmap):
    """Draw ``heatmap[0, 0]`` as an image using the ``gist_heat`` colormap.

    Args:
        heatmap: Array-like indexable as ``heatmap[0, 0]``; presumably laid
            out as (batch, channel, height, width) -- confirm against caller.

    Returns:
        The matplotlib figure holding the rendered image.
    """
    figure, axes = plt.subplots()
    # Only the first entry along the two leading axes is displayed.
    axes.imshow(heatmap[0, 0], cmap='gist_heat')
    figure.tight_layout()
    return figure
 
49
 
50
+ # def compute_scores
51
 
52
+
53
def run_inference(img, preset="edm2-img64-s-fid", device="cuda"):
    """Score one image and build the text and plots shown by the demo.

    Args:
        img: NumPy image array indexed as (height, width, channels); it is
            permuted to channels-first and given a leading batch axis.
        preset: Model preset name; also selects ``models/{preset}`` for the
            reference GMM data.
        device: Device to run on. A CUDA request falls back to ``"cpu"`` when
            CUDA is not available instead of crashing.

    Returns:
        A ``(text, heatmap_figure, histogram_figure)`` tuple.

    Raises:
        ValueError: If ``img`` is ``None`` (e.g. submitted without an upload).
    """
    if img is None:
        raise ValueError("No image provided")

    # Keep the demo usable on CPU-only hosts even with the "cuda" default.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    with torch.inference_mode():
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
        img = img.to(device)
        model = load_model(modeldir='models', preset=preset, device=device)
        x = model.scorenet(img)
        # L2 norm over dims 2-4 of the score network output.
        x = x.square().sum(dim=(2, 3, 4)) ** 0.5
        img_likelihood = model(img).cpu().numpy()
        nll, pct, ref_nll = compute_gmm_likelihood(x.cpu(), model_dir=f"models/{preset}")

    outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
    histplot = plot_against_reference(nll, ref_nll)
    heatmapplot = plot_heatmap(img_likelihood)

    return outstr, heatmapplot, histplot
70
 
71
 
72
  demo = gr.Interface(
73
  fn=run_inference,
74
  inputs=["image"],
75
+ outputs=["text",
76
+ gr.Plot(label="Anomaly Heatmap"),
77
+ gr.Plot(label="Comparing to Imagenette"),
78
+ ],
79
  )
80
 
81
  if __name__ == "__main__":
msma.py CHANGED
@@ -3,6 +3,7 @@ import pickle
3
  from functools import partial
4
  from pickle import dump, load
5
 
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
@@ -95,12 +96,12 @@ class EDMScorer(torch.nn.Module):
95
  class ScoreFlow(torch.nn.Module):
96
  def __init__(
97
  self,
98
- scorenet,
99
- vectorize=False,
100
  device="cpu",
101
  ):
102
  super().__init__()
103
 
 
104
  h = w = scorenet.net.img_resolution
105
  c = scorenet.net.img_channels
106
  num_sigmas = len(scorenet.sigma_steps)
@@ -134,9 +135,9 @@ def train_gmm(score_path, outdir, grid_search=False):
134
  gm = GaussianMixture(
135
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
136
  )
 
137
 
138
  if grid_search:
139
- clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
140
  param_grid = dict(
141
  GMM__n_components=range(2, 11, 1),
142
  )
@@ -184,10 +185,11 @@ def compute_gmm_likelihood(x_score, gmmdir):
184
  return nll, percentile
185
 
186
 
187
- def cache_score_norms(preset, dataset_path, device="cpu"):
188
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
189
  refimg, reflabel = dsobj[0]
190
- print(refimg.shape, refimg.dtype, reflabel)
 
191
  dsloader = torch.utils.data.DataLoader(
192
  dsobj, batch_size=48, num_workers=4, prefetch_factor=2
193
  )
@@ -202,8 +204,8 @@ def cache_score_norms(preset, dataset_path, device="cpu"):
202
 
203
  score_norms = torch.cat(score_norms, dim=0)
204
 
205
- os.makedirs("out/msma", exist_ok=True)
206
- with open(f"out/msma/{preset}_imagenette_score_norms.pt", "wb") as f:
207
  torch.save(score_norms, f)
208
 
209
  print(f"Computed score norms for {score_norms.shape[0]} samples")
@@ -232,7 +234,7 @@ def train_flow(dataset_path, preset, device="cuda"):
232
  val_ds, batch_size=48, num_workers=4, prefetch_factor=2
233
  )
234
 
235
- model = ScoreFlow(build_model(preset=preset), device=device)
236
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
237
  train_step = partial(
238
  PatchFlow.stochastic_step,
@@ -296,16 +298,15 @@ def test_runner(device="cpu"):
296
  return scores
297
 
298
 
299
- def test_flow_runner(device="cpu", load_weights=None):
300
- f = "doge.jpg"
301
- # f = "goldfish.JPEG"
302
  image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
303
  image = np.array(image)
304
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
305
  x = torch.from_numpy(image).unsqueeze(0).to(device)
306
- model = build_model(device=device)
307
 
308
- score_flow = ScoreFlow(scorenet=model, device=device)
309
 
310
  if load_weights is not None:
311
  score_flow.flow.load_state_dict(torch.load(load_weights))
@@ -323,13 +324,35 @@ def test_flow_runner(device="cpu", load_weights=None):
323
  return
324
 
325
 
326
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
327
  device = "cuda" if torch.cuda.is_available() else "cpu"
328
- preset = "edm2-img64-s-fid"
329
- imagenette_path = "/GROND_STOR/amahmood/datasets/img64/"
 
 
 
330
 
331
- train_flow(imagenette_path, preset, device)
332
- test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
 
 
 
 
 
 
 
 
 
 
333
 
334
  # cache_score_norms(
335
  # preset=preset,
@@ -344,3 +367,6 @@ if __name__ == "__main__":
344
  # s = s.to("cpu").numpy()
345
  # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
346
  # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
 
 
 
 
3
  from functools import partial
4
  from pickle import dump, load
5
 
6
+ import click
7
  import numpy as np
8
  import PIL.Image
9
  import torch
 
96
  class ScoreFlow(torch.nn.Module):
97
  def __init__(
98
  self,
99
+ preset,
 
100
  device="cpu",
101
  ):
102
  super().__init__()
103
 
104
+ scorenet = build_model(preset)
105
  h = w = scorenet.net.img_resolution
106
  c = scorenet.net.img_channels
107
  num_sigmas = len(scorenet.sigma_steps)
 
135
  gm = GaussianMixture(
136
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
137
  )
138
+ clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
139
 
140
  if grid_search:
 
141
  param_grid = dict(
142
  GMM__n_components=range(2, 11, 1),
143
  )
 
185
  return nll, percentile
186
 
187
 
188
+ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
189
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
190
  refimg, reflabel = dsobj[0]
191
+ print(f"Loading dataset from {dataset_path}")
192
+ print(f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}")
193
  dsloader = torch.utils.data.DataLoader(
194
  dsobj, batch_size=48, num_workers=4, prefetch_factor=2
195
  )
 
204
 
205
  score_norms = torch.cat(score_norms, dim=0)
206
 
207
+ os.makedirs(f"{outdir}/{preset}/", exist_ok=True)
208
+ with open(f"{outdir}/{preset}/imagenette_score_norms.pt", "wb") as f:
209
  torch.save(score_norms, f)
210
 
211
  print(f"Computed score norms for {score_norms.shape[0]} samples")
 
234
  val_ds, batch_size=48, num_workers=4, prefetch_factor=2
235
  )
236
 
237
+ model = ScoreFlow(preset, device=device)
238
  opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
239
  train_step = partial(
240
  PatchFlow.stochastic_step,
 
298
  return scores
299
 
300
 
301
+ def test_flow_runner(preset, device="cpu", load_weights=None):
302
+ # f = "doge.jpg"
303
+ f = "goldfish.JPEG"
304
  image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
305
  image = np.array(image)
306
  image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
307
  x = torch.from_numpy(image).unsqueeze(0).to(device)
 
308
 
309
+ score_flow = ScoreFlow(preset, device=device)
310
 
311
  if load_weights is not None:
312
  score_flow.flow.load_state_dict(torch.load(load_weights))
 
324
  return
325
 
326
 
327
@click.command()
# Main options.
@click.option(
    '--run', help='Which function to run', required=True,
    type=click.Choice(['cache-scores', 'train-flow', 'train-gmm'], case_sensitive=False),
)
@click.option('--outdir', help='Where to load/save the results', metavar='DIR', type=str, required=True)
@click.option('--preset', help='Configuration preset', metavar='STR', type=str, default='edm2-img64-s-fid', show_default=True)
@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, default=None)
def cmdline(run, outdir, **opts):
    """Command-line entry point that dispatches to one pipeline stage.

    Args:
        run: Stage to execute: ``cache-scores``, ``train-flow`` or
            ``train-gmm``.
        outdir: Root directory for results; per-preset files live under
            ``{outdir}/{preset}``.
        **opts: Remaining click options: ``preset`` (configuration preset
            name) and ``data`` (dataset path, needed by ``cache-scores`` and
            ``train-flow``).

    Raises:
        click.UsageError: If the selected stage needs ``--data`` and it was
            not supplied.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    preset = opts['preset']
    dataset_path = opts['data']

    # Explicit error instead of `assert`: asserts are stripped under `python -O`.
    if run in ('cache-scores', 'train-flow') and dataset_path is None:
        raise click.UsageError("Provide path to dataset (--data)")

    if run == "cache-scores":
        cache_score_norms(preset=preset, dataset_path=dataset_path, outdir=outdir, device=device)
    elif run == "train-flow":
        # NOTE(review): train_flow's visible signature has no output-directory
        # parameter, so --outdir is not forwarded here -- confirm where it saves.
        train_flow(dataset_path, preset, device)
    elif run == "train-gmm":
        train_gmm(
            score_path=f"{outdir}/{preset}/imagenette_score_norms.pt",
            outdir=f"{outdir}/{preset}",
            grid_search=True,
        )
356
 
357
  # cache_score_norms(
358
  # preset=preset,
 
367
  # s = s.to("cpu").numpy()
368
  # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
369
  # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
370
+
371
+ if __name__ == "__main__":
372
+ cmdline()