"""
Approximate the bits/dimension for an image model.
"""
import argparse
import os

import numpy as np
import torch.distributed as dist

from pixel_guide_diffusion import dist_util, logger
from pixel_guide_diffusion.image_datasets import load_data
from pixel_guide_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    model.eval()

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    logger.log("evaluating...")
    run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)
def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        batch = batch.to(dist_util.dev())
        model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        # Average each per-timestep metric over the batch, then over ranks:
        # dividing by the world size before the all_reduce sum yields the mean.
        for key, term_list in all_metrics.items():
            terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
            dist.all_reduce(terms)
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = minibatch_metrics["total_bpd"]
        total_bpd = total_bpd.mean() / dist.get_world_size()
        dist.all_reduce(total_bpd)
        all_bpd.append(total_bpd.item())
        num_complete += dist.get_world_size() * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    # Only rank 0 writes the per-timestep terms to disk.
    if dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    dist.barrier()
    logger.log("evaluation complete")
def create_argparser():
    defaults = dict(
        data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()