ahsanMah commited on
Commit
3c00c76
·
1 Parent(s): deb6231

trying to separate out hf app

Browse files
Files changed (2) hide show
  1. app.py +32 -26
  2. hfapp.py +55 -0
app.py CHANGED
@@ -127,33 +127,39 @@ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False
127
 
128
  return outstr, heatmapplot, histplot
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- demo = gr.Interface(
132
- fn=localize_anomalies,
133
- inputs=[
134
- gr.Image(type="pil", label="Input Image"),
135
- gr.Dropdown(
136
- choices=config_presets.keys(),
137
- label="Score Model Preset",
138
- info="The preset of the underlying score estimator. These are the EDM2 diffusion models from Karras et.al.",
139
- ),
140
- gr.Checkbox(
141
- label="HuggingFace Hub",
142
- value=True,
143
- info="Load a pretrained model from HuggingFace. Uncheck to use a model from `models` directory.",
144
- ),
145
- ],
146
- outputs=[
147
- gr.Text(label="Estimated global outlier scores - Percentiles with respect to Imagenette Scores"),
148
- gr.Image(label="Anomaly Heatmap", min_width=64),
149
- gr.Plot(label="Comparing to Imagenette"),
150
- ],
151
- examples=[
152
- ["samples/duckelephant.jpeg", "edm2-img64-s-fid", True],
153
- ["samples/sharkhorse.jpeg", "edm2-img64-s-fid", True],
154
- ["samples/goldfish.jpeg", "edm2-img64-s-fid", True],
155
- ],
156
- )
157
 
 
158
  if __name__ == "__main__":
159
  demo.launch()
 
127
 
128
  return outstr, heatmapplot, histplot
129
 
130
+ def build_demo(inference_fn):
131
+
132
+ demo = gr.Interface(
133
+ fn=inference_fn,
134
+ inputs=[
135
+ gr.Image(type="pil", label="Input Image"),
136
+ gr.Dropdown(
137
+ choices=config_presets.keys(),
138
+ label="Score Model Preset",
139
+ info="The preset of the underlying score estimator. These are the EDM2 diffusion models from Karras et.al.",
140
+ ),
141
+ gr.Checkbox(
142
+ label="HuggingFace Hub",
143
+ value=True,
144
+ info="Load a pretrained model from HuggingFace. Uncheck to use a model from `models` directory.",
145
+ ),
146
+ ],
147
+ outputs=[
148
+ gr.Text(
149
+ label="Estimated global outlier scores - Percentiles with respect to Imagenette Scores"
150
+ ),
151
+ gr.Image(label="Anomaly Heatmap", min_width=64),
152
+ gr.Plot(label="Comparing to Imagenette"),
153
+ ],
154
+ examples=[
155
+ ["samples/duckelephant.jpeg", "edm2-img64-s-fid", True],
156
+ ["samples/sharkhorse.jpeg", "edm2-img64-s-fid", True],
157
+ ["samples/goldfish.jpeg", "edm2-img64-s-fid", True],
158
+ ],
159
+ )
160
 
161
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ demo = build_demo(localize_anomalies)
164
  if __name__ == "__main__":
165
  demo.launch()
hfapp.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL.Image as Image
3
+ import spaces
4
+ import torch
5
+
6
+ from app import (
7
+ build_demo,
8
+ compute_gmm_likelihood,
9
+ load_model_from_hub,
10
+ plot_against_reference,
11
+ plot_heatmap,
12
+ )
13
+
14
+
15
+ @spaces.GPU
16
+ def run_inference(model, img):
17
+ img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
18
+ score_norms = model.scorenet(img)
19
+ score_norms = score_norms.square().sum(dim=(2, 3, 4)) ** 0.5
20
+ img_likelihood = model(img).cpu().numpy()
21
+ score_norms = score_norms.cpu().numpy()
22
+ return img_likelihood, score_norms
23
+
24
+
25
+ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
26
+ print("NEW LOCALIZE")
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ # img = center_crop_imagenet(64, img)
29
+ input_img = input_img.resize(size=(64, 64), resample=Image.Resampling.LANCZOS)
30
+
31
+ with torch.inference_mode():
32
+ img = np.array(input_img)
33
+ img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
34
+ img = img.float().to(device)
35
+ if load_from_hub:
36
+ model = load_model_from_hub(preset=preset, device=device)
37
+ else:
38
+ model = load_model(modeldir="models", preset=preset, device=device)
39
+
40
+ img_likelihood, score_norms = run_inference(model, img)
41
+ nll, pct, ref_nll = compute_gmm_likelihood(
42
+ score_norms, model_dir=f"models/{preset}"
43
+ )
44
+
45
+ outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
46
+ histplot = plot_against_reference(nll, ref_nll)
47
+ heatmapplot = plot_heatmap(input_img, img_likelihood)
48
+
49
+ return outstr, heatmapplot, histplot
50
+
51
+
52
+
53
+ demo = build_demo(localize_anomalies)
54
+ if __name__ == "__main__":
55
+ demo.launch()