ahsanMah commited on
Commit
52f9197
·
1 Parent(s): 69e2af3

+ caching model

Browse files

+ displaying basic hist plot

Files changed (1) hide show
  1. app.py +33 -8
app.py CHANGED
@@ -1,38 +1,63 @@
 
1
  from pickle import load
2
 
3
  import gradio as gr
 
4
  import numpy as np
5
  import torch
6
 
7
  from scorer import build_model
8
 
9
 
 
 
 
 
 
 
 
 
 
 
10
  def compute_gmm_likelihood(x_score, gmmdir='models'):
11
  with open(f"{gmmdir}/gmm.pkl", "rb") as f:
12
  clf = load(f)
13
  nll = -clf.score(x_score)
14
 
15
- with np.load(f"{gmmdir}/refscores.npz", "wb") as f:
16
- ref_nll = f["arr_0"]
17
- percentile = (ref_nll < nll).mean() * 100
18
 
19
  return nll, percentile
20
 
21
- def run_inference(img):
 
 
 
 
 
 
 
 
 
 
22
  img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
23
  img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
24
- model = build_model(device='cuda')
25
  x = model(img.cuda())
26
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
27
  nll, pct = compute_gmm_likelihood(x.cpu())
28
 
29
- return f"Image of shape: {img.shape} -> {nll:.3f}@{pct:.2f}"
 
 
 
30
 
31
 
32
  demo = gr.Interface(
33
  fn=run_inference,
34
  inputs=["image"],
35
- outputs=["text"],
36
  )
37
 
38
- demo.launch()
 
 
1
+ from functools import cache
2
  from pickle import load
3
 
4
  import gradio as gr
5
+ import matplotlib.pyplot as plt
6
  import numpy as np
7
  import torch
8
 
9
  from scorer import build_model
10
 
11
 
12
+ @cache
13
+ def load_model(device):
14
+ return build_model(device=device)
15
+
16
+ @cache
17
+ def load_reference_scores(gmmdir='models'):
18
+ with np.load(f"{gmmdir}/refscores.npz", "rb") as f:
19
+ ref_nll = f["arr_0"]
20
+ return ref_nll
21
+
22
  def compute_gmm_likelihood(x_score, gmmdir='models'):
23
  with open(f"{gmmdir}/gmm.pkl", "rb") as f:
24
  clf = load(f)
25
  nll = -clf.score(x_score)
26
 
27
+ ref_nll = load_reference_scores(gmmdir)
28
+ percentile = (ref_nll < nll).mean() * 100
 
29
 
30
  return nll, percentile
31
 
32
+ def plot_against_reference(nll):
33
+ ref_nll = load_reference_scores()
34
+ print(ref_nll.shape)
35
+ fig, ax = plt.subplots()
36
+ ax.hist(ref_nll)
37
+ ax.axvline(nll, label='Image Score', c='red', ls="--")
38
+ plt.legend()
39
+ fig.tight_layout()
40
+ return fig
41
+
42
+ def run_inference(img, device='cuda'):
43
  img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
44
  img = torch.nn.functional.interpolate(img, size=64, mode='bilinear')
45
+ model = load_model(device=device)
46
  x = model(img.cuda())
47
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
48
  nll, pct = compute_gmm_likelihood(x.cpu())
49
 
50
+ plot = plot_against_reference(nll)
51
+ print(plot)
52
+ outstr = f"Anomaly score: {nll:.3f} -> {pct:.2f} percentile"
53
+ return outstr, plot
54
 
55
 
56
  demo = gr.Interface(
57
  fn=run_inference,
58
  inputs=["image"],
59
+ outputs=["text", gr.Plot()],
60
  )
61
 
62
+ if __name__ == "__main__":
63
+ demo.launch()