ahsanMah commited on
Commit
f1e86bd
·
1 Parent(s): 850e111

factoring out inference function

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -68,7 +68,7 @@ def compute_gmm_likelihood(x_score, model_dir):
68
 
69
  def plot_against_reference(nll, ref_nll):
70
  fig, ax = plt.subplots()
71
- ax.hist(ref_nll, label="Reference Scores")
72
  ax.axvline(nll, label="Image Score", c="red", ls="--")
73
  plt.legend()
74
  fig.tight_layout()
@@ -93,7 +93,15 @@ def plot_heatmap(img: Image, heatmap: np.array):
93
  return im
94
 
95
 
96
- def run_inference(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
 
 
 
 
 
 
 
 
97
 
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  # img = center_crop_imagenet(64, img)
@@ -108,12 +116,9 @@ def run_inference(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
108
  else:
109
  model = load_model(modeldir="models", preset=preset, device=device)
110
 
111
- img_likelihood = model(img).cpu().numpy()
112
- img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
113
- x = model.scorenet(img)
114
- x = x.square().sum(dim=(2, 3, 4)) ** 0.5
115
  nll, pct, ref_nll = compute_gmm_likelihood(
116
- x.cpu(), model_dir=f"models/{preset}"
117
  )
118
 
119
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
@@ -124,7 +129,7 @@ def run_inference(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
124
 
125
 
126
  demo = gr.Interface(
127
- fn=run_inference,
128
  inputs=[
129
  gr.Image(type="pil", label="Input Image"),
130
  gr.Dropdown(
@@ -139,7 +144,7 @@ demo = gr.Interface(
139
  ),
140
  ],
141
  outputs=[
142
- "text",
143
  gr.Image(label="Anomaly Heatmap", min_width=64),
144
  gr.Plot(label="Comparing to Imagenette"),
145
  ],
 
68
 
69
  def plot_against_reference(nll, ref_nll):
70
  fig, ax = plt.subplots()
71
+ ax.hist(ref_nll, label="Reference Scores", bins=25)
72
  ax.axvline(nll, label="Image Score", c="red", ls="--")
73
  plt.legend()
74
  fig.tight_layout()
 
93
  return im
94
 
95
 
96
+ def run_inference(model, img):
97
+ img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
98
+ score_norms = model.scorenet(img)
99
+ score_norms = score_norms.square().sum(dim=(2, 3, 4)) ** 0.5
100
+ img_likelihood = model(img).cpu().numpy()
101
+ score_norms = score_norms.cpu().numpy()
102
+ return img_likelihood, score_norms
103
+
104
+ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
105
 
106
  device = "cuda" if torch.cuda.is_available() else "cpu"
107
  # img = center_crop_imagenet(64, img)
 
116
  else:
117
  model = load_model(modeldir="models", preset=preset, device=device)
118
 
119
+ img_likelihood, score_norms = run_inference(model, img)
 
 
 
120
  nll, pct, ref_nll = compute_gmm_likelihood(
121
+ score_norms, model_dir=f"models/{preset}"
122
  )
123
 
124
  outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
 
129
 
130
 
131
  demo = gr.Interface(
132
+ fn=localize_anomalies,
133
  inputs=[
134
  gr.Image(type="pil", label="Input Image"),
135
  gr.Dropdown(
 
144
  ),
145
  ],
146
  outputs=[
147
+ gr.Text(label="Estimated global outlier scores - Percentiles with respect to Imagenette Scores"),
148
  gr.Image(label="Anomaly Heatmap", min_width=64),
149
  gr.Plot(label="Comparing to Imagenette"),
150
  ],