Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import pickle | |
from pickle import dump, load | |
import numpy as np | |
import PIL.Image | |
import torch | |
from sklearn.mixture import GaussianMixture | |
from sklearn.pipeline import Pipeline | |
from sklearn.preprocessing import StandardScaler | |
from tqdm import tqdm | |
import dnnlib | |
class EDMScorer(torch.nn.Module): | |
def __init__( | |
self, | |
net, | |
stop_ratio=0.8, # Maximum ratio of noise levels to compute | |
num_steps=10, # Number of noise levels to evaluate. | |
use_fp16=False, # Execute the underlying model at FP16 precision? | |
sigma_min=0.002, # Minimum supported noise level. | |
sigma_max=80, # Maximum supported noise level. | |
sigma_data=0.5, # Expected standard deviation of the training data. | |
rho=7, # Time step discretization. | |
device=torch.device("cpu"), # Device to use. | |
): | |
super().__init__() | |
self.use_fp16 = use_fp16 | |
self.sigma_min = sigma_min | |
self.sigma_max = sigma_max | |
self.sigma_data = sigma_data | |
self.net = net.eval() | |
# Adjust noise levels based on how far we want to accumulate | |
self.sigma_min = sigma_min | |
self.sigma_max = sigma_max * stop_ratio | |
step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) | |
t_steps = ( | |
sigma_max ** (1 / rho) | |
+ step_indices | |
/ (num_steps - 1) | |
* (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) | |
) ** rho | |
print("Using steps:", t_steps) | |
self.register_buffer("sigma_steps", t_steps.to(torch.float64)) | |
def forward( | |
self, | |
x, | |
force_fp32=False, | |
): | |
x = x.to(torch.float32) | |
batch_scores = [] | |
for sigma in self.sigma_steps: | |
xhat = self.net(x, sigma, force_fp32=force_fp32) | |
c_skip = self.net.sigma_data**2 / (sigma**2 + self.net.sigma_data**2) | |
score = xhat - (c_skip * x) | |
# score_norms = score.mean(1) | |
# score_norms = score.square().sum(dim=(1, 2, 3)) ** 0.5 | |
batch_scores.append(score) | |
batch_scores = torch.stack(batch_scores, axis=1) | |
return batch_scores | |
def build_model(netpath=f"edm2-img64-s-1073741-0.075.pkl", device="cpu"): | |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions" | |
netpath = f"{model_root}/{netpath}" | |
with dnnlib.util.open_url(netpath, verbose=1) as f: | |
data = pickle.load(f) | |
net = data["ema"] | |
model = EDMScorer(net, num_steps=20).to(device) | |
return model | |
def train_gmm(score_path, outdir="out/msma/"): | |
X = torch.load(score_path) | |
gm = GaussianMixture(n_components=5, random_state=42) | |
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)]) | |
clf.fit(X) | |
inlier_nll = -clf.score_samples(X) | |
with open(f"{outdir}/refscores.npz", "wb") as f: | |
np.savez_compressed(f, inlier_nll) | |
with open(f"{outdir}/gmm.pkl", "wb") as f: | |
dump(clf, f, protocol=5) | |
def compute_gmm_likelihood(x_score, gmmdir): | |
with open(f"{gmmdir}/gmm.pkl", "rb") as f: | |
clf = load(f) | |
nll = -clf.score_samples(x_score) | |
with np.load(f"{gmmdir}/refscores.npz", "wb") as f: | |
ref_nll = f["arr_0"] | |
percentile = (ref_nll < nll).mean() | |
return nll, percentile | |
def test_runner(device="cpu"): | |
f = "goldfish.JPEG" | |
image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS) | |
image = np.array(image) | |
image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1) | |
x = torch.from_numpy(image).unsqueeze(0).to(device) | |
model = build_model(device=device) | |
scores = model(x) | |
return scores | |
def runner(dataset_path, device="cpu"): | |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64) | |
refimg, reflabel = dsobj[0] | |
print(refimg.shape, refimg.dtype, reflabel) | |
dsloader = torch.utils.data.DataLoader( | |
dsobj, batch_size=48, num_workers=4, prefetch_factor=2 | |
) | |
model = build_model(device=device) | |
score_norms = [] | |
for x, _ in tqdm(dsloader): | |
s = model(x.to(device)) | |
s = s.square().sum(dim=(2, 3, 4)) ** 0.5 | |
score_norms.append(s.cpu()) | |
score_norms = torch.cat(score_norms, dim=0) | |
os.makedirs("out/msma", exist_ok=True) | |
with open("out/msma/imagenette64_score_norms.pt", "wb") as f: | |
torch.save(score_norms, f) | |
print(f"Computed score norms for {score_norms.shape[0]} samples") | |
if __name__ == "__main__": | |
# runner("/GROND_STOR/amahmood/datasets/img64/", device="cuda") | |
train_gmm("out/msma/imagenette64_score_norms.pt") | |
s = test_runner(device="cuda") | |
s = s.square().sum(dim=(2, 3, 4)) ** 0.5 | |
s = s.to("cpu").numpy() | |
nll, pct = compute_gmm_likelihood(s, gmmdir="out/msma/") | |
print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile") | |