ahsanMah commited on
Commit
ffaef20
·
1 Parent(s): 7358cfe

+ loading from hub is optional

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -15,10 +15,12 @@ from msma import ScoreFlow, build_model_from_pickle, config_presets
15
 
16
  @cache
17
  def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
18
- model = ScoreFlow(preset, num_flows=8, device=device)
19
- model.flow.load_state_dict(torch.load(f"{modeldir}/nb8/{preset}/flow.pt"))
 
20
  return model
21
 
 
22
  @cache
23
  def load_model_from_hub(preset, device):
24
  scorenet = build_model_from_pickle(preset)
@@ -40,7 +42,7 @@ def load_model_from_hub(preset, device):
40
  cache_dir="/tmp/",
41
  )
42
 
43
- model = ScoreFlow(scorenet, device=device, **model_params['PatchFlow'])
44
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
45
 
46
  return model
@@ -91,7 +93,7 @@ def plot_heatmap(img: Image, heatmap: np.array):
91
  return im
92
 
93
 
94
- def run_inference(input_img, preset="edm2-img64-s-fid"):
95
 
96
  device = "cuda" if torch.cuda.is_available() else "cpu"
97
  # img = center_crop_imagenet(64, img)
@@ -101,11 +103,12 @@ def run_inference(input_img, preset="edm2-img64-s-fid"):
101
  img = np.array(input_img)
102
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
103
  img = img.float().to(device)
104
- # model = load_model(modeldir="models", preset=preset, device=device)
105
- model = load_model_from_hub(preset=preset, device=device)
 
 
 
106
  img_likelihood = model(img).cpu().numpy()
107
- # img_likelihood = model.scorenet(img).square().sum(1).sum(1).contiguous().float().cpu().unsqueeze(1).numpy()
108
- # print(img_likelihood.shape, img_likelihood.dtype)
109
  img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
110
  x = model.scorenet(img)
111
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
@@ -124,14 +127,27 @@ demo = gr.Interface(
124
  fn=run_inference,
125
  inputs=[
126
  gr.Image(type="pil", label="Input Image"),
127
- gr.Dropdown(choices=config_presets.keys(), label="Score Model"),
 
 
 
 
 
 
 
 
 
128
  ],
129
  outputs=[
130
  "text",
131
  gr.Image(label="Anomaly Heatmap", min_width=64),
132
  gr.Plot(label="Comparing to Imagenette"),
133
  ],
134
- examples=[["goldfish.JPEG", "edm2-img64-s-fid"]],
 
 
 
 
135
  )
136
 
137
  if __name__ == "__main__":
 
15
 
16
  @cache
17
  def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
18
+ scorenet = build_model_from_pickle(preset=preset)
19
+ model = ScoreFlow(scorenet, num_flows=8, device=device)
20
+ model.flow.load_state_dict(torch.load(f"{modeldir}/comb/{preset}/flow.pt"))
21
  return model
22
 
23
+
24
  @cache
25
  def load_model_from_hub(preset, device):
26
  scorenet = build_model_from_pickle(preset)
 
42
  cache_dir="/tmp/",
43
  )
44
 
45
+ model = ScoreFlow(scorenet, device=device, **model_params["PatchFlow"])
46
  model.load_state_dict(load_file(hf_checkpoint), strict=True)
47
 
48
  return model
 
93
  return im
94
 
95
 
96
+ def run_inference(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
97
 
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  # img = center_crop_imagenet(64, img)
 
103
  img = np.array(input_img)
104
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
105
  img = img.float().to(device)
106
+ if load_from_hub:
107
+ model = load_model_from_hub(preset=preset, device=device)
108
+ else:
109
+ model = load_model(modeldir="models", preset=preset, device=device)
110
+
111
  img_likelihood = model(img).cpu().numpy()
 
 
112
  img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
113
  x = model.scorenet(img)
114
  x = x.square().sum(dim=(2, 3, 4)) ** 0.5
 
127
  fn=run_inference,
128
  inputs=[
129
  gr.Image(type="pil", label="Input Image"),
130
+ gr.Dropdown(
131
+ choices=config_presets.keys(),
132
+ label="Score Model Preset",
133
+ info="The preset of the underlying score estimator. These are the EDM2 diffusion models from Karras et.al.",
134
+ ),
135
+ gr.Checkbox(
136
+ label="HuggingFace Hub",
137
+ value=True,
138
+ info="Load a pretrained model from HuggingFace. Uncheck to use a model from `models` directory.",
139
+ ),
140
  ],
141
  outputs=[
142
  "text",
143
  gr.Image(label="Anomaly Heatmap", min_width=64),
144
  gr.Plot(label="Comparing to Imagenette"),
145
  ],
146
+ examples=[
147
+ ["samples/duckelephant.jpeg", "edm2-img64-s-fid", True],
148
+ ["samples/sharkhorse.jpeg", "edm2-img64-s-fid", True],
149
+ ["samples/goldfish.jpeg", "edm2-img64-s-fid", True],
150
+ ],
151
  )
152
 
153
  if __name__ == "__main__":