# Hugging Face file-viewer residue (kept for provenance; not part of the program):
# ahsanMah's picture
# init commit of files
# be66f33
# raw
# history blame
# 4.88 kB
import os
import pickle
from pickle import dump, load
import numpy as np
import PIL.Image
import torch
from sklearn.mixture import GaussianMixture
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import dnnlib
class EDMScorer(torch.nn.Module):
    """Evaluate a denoiser network over a fixed schedule of noise levels.

    For every noise level ``sigma`` in the schedule the wrapped network is
    called on the input and the quantity ``net(x, sigma) - c_skip * x`` is
    recorded, where ``c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)``
    and ``sigma_data`` is read from the wrapped network.

    The schedule is a ``rho``-spaced interpolation from
    ``sigma_max * stop_ratio`` (first step) down to ``sigma_min`` (last step).

    Args:
        net: Network called as ``net(x, sigma, force_fp32=...)``; it must
            expose a ``sigma_data`` attribute. It is switched to eval mode.
        stop_ratio: Fraction of ``sigma_max`` used as the largest noise level.
        num_steps: Number of noise levels to evaluate.
        use_fp16: Stored on the instance; not otherwise used by this class.
        sigma_min: Smallest noise level in the schedule.
        sigma_max: Largest supported noise level (scaled by ``stop_ratio``).
        sigma_data: Stored on the instance; ``forward`` reads the value from
            ``net.sigma_data`` instead.
        rho: Exponent controlling the spacing of the schedule.
        device: Device on which the schedule is created.
    """

    def __init__(
        self,
        net,
        stop_ratio=0.8,  # Maximum ratio of noise levels to compute
        num_steps=10,  # Number of noise levels to evaluate.
        use_fp16=False,  # Execute the underlying model at FP16 precision?
        sigma_min=0.002,  # Minimum supported noise level.
        sigma_max=80,  # Maximum supported noise level.
        sigma_data=0.5,  # Expected standard deviation of the training data.
        rho=7,  # Time step discretization.
        device=torch.device("cpu"),  # Device to use.
    ):
        super().__init__()
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.net = net.eval()

        # Adjust noise levels based on how far we want to accumulate.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max * stop_ratio

        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        # Guard the denominator: with a single step the original expression
        # evaluated 0 / 0 and filled the schedule with NaN.
        denominator = max(num_steps - 1, 1)
        inv_rho = 1 / rho
        # Build the schedule from the *adjusted* bounds so that stop_ratio
        # actually takes effect.
        t_steps = (
            self.sigma_max**inv_rho
            + step_indices
            / denominator
            * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** rho

        print("Using steps:", t_steps)

        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer("sigma_steps", t_steps.to(torch.float64))

    @torch.inference_mode()
    def forward(
        self,
        x,
        force_fp32=False,
    ):
        """Return the per-noise-level outputs for a batch, stacked on dim 1.

        Args:
            x: Input batch; cast to float32 before use. Callers in this file
                reduce the result over dims (2, 3, 4), so a 4-D
                (batch, channel, height, width) input is presumably expected
                -- TODO confirm.
            force_fp32: Forwarded unchanged to the wrapped network.

        Returns:
            Tensor whose dim 1 indexes the noise levels in ``sigma_steps``.
        """
        x = x.to(torch.float32)
        sigma_data_sq = self.net.sigma_data**2

        batch_scores = []
        for sigma in self.sigma_steps:
            xhat = self.net(x, sigma, force_fp32=force_fp32)
            c_skip = sigma_data_sq / (sigma**2 + sigma_data_sq)
            batch_scores.append(xhat - (c_skip * x))

        return torch.stack(batch_scores, dim=1)
def build_model(netpath="edm2-img64-s-1073741-0.075.pkl", device="cpu"):
    """Download a pickled network and wrap it in a 20-step ``EDMScorer``.

    Args:
        netpath: File name appended to the fixed model root URL.
        device: Device the resulting scorer is moved to.

    Returns:
        The ``EDMScorer`` built around the ``"ema"`` entry of the pickle.
    """
    model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
    url = f"{model_root}/{netpath}"

    # SECURITY NOTE: unpickling executes arbitrary code from the payload, so
    # only use this with model files from a source you trust.
    with dnnlib.util.open_url(url, verbose=1) as stream:
        checkpoint = pickle.load(stream)

    ema_net = checkpoint["ema"]
    scorer = EDMScorer(ema_net, num_steps=20)
    return scorer.to(device)
def train_gmm(score_path, outdir="out/msma/"):
    """Fit a standardize-then-GMM pipeline and save it with reference scores.

    Two files are written into ``outdir``:
      * ``refscores.npz`` -- the negated ``score_samples`` values of the
        training data (stored under numpy's default key ``arr_0``).
      * ``gmm.pkl`` -- the fitted pipeline, pickled with protocol 5.

    Args:
        score_path: Path passed to ``torch.load``; the loaded object is fed
            directly to the scikit-learn pipeline (presumably a 2-D
            samples-by-features tensor -- TODO confirm).
        outdir: Directory receiving the output files; created if missing.
    """
    X = torch.load(score_path)

    gm = GaussianMixture(n_components=5, random_state=42)
    clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
    clf.fit(X)
    inlier_nll = -clf.score_samples(X)

    # Previously a missing output directory made the open() calls below fail
    # after the (potentially slow) fit had already completed.
    os.makedirs(outdir, exist_ok=True)

    with open(f"{outdir}/refscores.npz", "wb") as f:
        np.savez_compressed(f, inlier_nll)

    with open(f"{outdir}/gmm.pkl", "wb") as f:
        dump(clf, f, protocol=5)
def compute_gmm_likelihood(x_score, gmmdir):
    """Score samples with the saved pipeline and rank them against references.

    Args:
        x_score: Feature array passed to the pipeline's ``score_samples``.
        gmmdir: Directory containing ``gmm.pkl`` and ``refscores.npz``.

    Returns:
        A ``(nll, percentile)`` tuple. ``nll`` is the negated
        ``score_samples`` output. ``percentile`` is the fraction of reference
        values strictly below each sample's value: a scalar when exactly one
        sample is scored (as before), otherwise one value per sample.
    """
    # NOTE: unpickling can execute arbitrary code; only point gmmdir at
    # files you produced or otherwise trust.
    with open(f"{gmmdir}/gmm.pkl", "rb") as f:
        clf = load(f)
    nll = -clf.score_samples(x_score)

    # The second positional argument of np.load is mmap_mode, not a file
    # mode, so the stray "wb" has been dropped.
    with np.load(f"{gmmdir}/refscores.npz") as f:
        ref_nll = f["arr_0"]

    if nll.size == 1:
        # Single sample: keep the scalar result existing callers format.
        percentile = (ref_nll < nll).mean()
    else:
        # Batch: compare every sample against every reference value instead
        # of relying on accidental elementwise broadcasting.
        percentile = (ref_nll.reshape(1, -1) < nll.reshape(-1, 1)).mean(axis=1)

    return nll, percentile
def test_runner(device="cpu", image_path="goldfish.JPEG"):
    """Run the scorer on a single image file and return its raw output.

    Args:
        device: Device the image tensor and the model are placed on.
        image_path: Image to score; defaults to the previously hard-coded
            ``goldfish.JPEG`` in the current working directory.

    Returns:
        The tensor returned by the model for the one-image batch.
    """
    # The context manager closes the source file; resize() returns a new,
    # independent image.
    with PIL.Image.open(image_path) as source:
        image = source.resize((64, 64), PIL.Image.Resampling.LANCZOS)

    image = np.array(image)
    # Give 2-D (single-channel) arrays a trailing channel axis, then move
    # channels first: (H, W, C) -> (C, H, W).
    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
    x = torch.from_numpy(image).unsqueeze(0).to(device)

    model = build_model(device=device)
    scores = model(x)
    return scores
def runner(dataset_path, device="cpu"):
    """Compute per-noise-level score norms for a dataset and save them.

    The norms are written to ``out/msma/imagenette64_score_norms.pt``.

    Args:
        dataset_path: Location handed to ``ImageFolderDataset``.
        device: Device used for the model and each batch.
    """
    # NOTE(review): ImageFolderDataset is not imported in the visible part of
    # this file -- confirm where it is expected to come from.
    dataset = ImageFolderDataset(path=dataset_path, resolution=64)

    first_image, first_label = dataset[0]
    print(first_image.shape, first_image.dtype, first_label)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=48, num_workers=4, prefetch_factor=2
    )
    model = build_model(device=device)

    # L2 norm over dims (2, 3, 4) of the model output, one row per sample.
    per_batch_norms = [
        (model(batch.to(device)).square().sum(dim=(2, 3, 4)) ** 0.5).cpu()
        for batch, _ in tqdm(loader)
    ]
    score_norms = torch.cat(per_batch_norms, dim=0)

    os.makedirs("out/msma", exist_ok=True)
    with open("out/msma/imagenette64_score_norms.pt", "wb") as out_file:
        torch.save(score_norms, out_file)

    print(f"Computed score norms for {score_norms.shape[0]} samples")
if __name__ == "__main__":
    # The score-norm file must already exist; it can be regenerated with e.g.
    # runner("/GROND_STOR/amahmood/datasets/img64/", device="cuda")
    train_gmm("out/msma/imagenette64_score_norms.pt")

    # Score one test image and reduce to a per-noise-level L2 norm.
    raw_scores = test_runner(device="cuda")
    norms = raw_scores.square().sum(dim=(2, 3, 4)) ** 0.5
    norms = norms.to("cpu").numpy()

    nll, pct = compute_gmm_likelihood(norms, gmmdir="out/msma/")
    print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")