# AdcSR / test.py
# Guaishou74851's picture
# Upload 66 files
# 34b61ae verified
# raw
# history blame
# 2.66 kB
import torch, os, glob, copy
import torch.nn.functional as F
import numpy as np
from PIL import Image
from argparse import ArgumentParser
from torchvision import transforms
from model import Net
# Command-line interface. Every option has a default, so the script can be
# started without any arguments.
parser = ArgumentParser()
for _flag, _kind, _fallback in (
    ("--epoch", int, 200),                    # number used in the net_params_<epoch>.pkl file name
    ("--model_dir", str, "weight"),           # directory holding that checkpoint file
    ("--LR_dir", str, "testset/RealSR/LR"),   # folder searched for low-resolution *.png inputs
    ("--HR_dir", str, "testset/RealSR/HR"),   # folder searched for high-resolution *.png files
    ("--SR_dir", str, "result/RealSR"),       # folder the super-resolved images are written to
):
    parser.add_argument(_flag, type=_kind, default=_fallback)
args = parser.parse_args()
from diffusers import StableDiffusionPipeline

# All modules and tensors in this script are placed on the CUDA device.
device = torch.device("cuda")

# Load the pretrained Stable Diffusion 2.1 (base) pipeline, move it to the
# device, and expose its components under module-level names.
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to(device)
vae, tokenizer, text_encoder = pipe.vae, pipe.tokenizer, pipe.text_encoder
unet, noise_scheduler = pipe.unet, pipe.scheduler
from diffusers.models.autoencoders.vae import Decoder

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point this at checkpoint files you trust.
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; load_state_dict() below copies the values into the decoder's
# parameters, which already live on `device`.
ckpt_halfdecoder = torch.load(
    "./weight/pretrained/halfDecoder.ckpt",
    map_location="cpu",
    weights_only=False,
)

# Decoder with four UpDecoderBlock2D stages, built on the target device.
decoder = Decoder(
    in_channels=4,
    out_channels=3,
    up_block_types=["UpDecoderBlock2D"] * 4,
    block_out_channels=[64, 128, 256, 256],
    layers_per_block=2,
    norm_num_groups=32,
    act_fn="silu",
    norm_type="group",
    mid_block_add_attention=True,
).to(device)

# Keep only the checkpoint entries whose key mentions "decoder" and remove
# the "decoder." text from the key so it matches the bare Decoder module.
decoder_ckpt = {
    k.replace("decoder.", ""): v
    for k, v in ckpt_halfdecoder["state_dict"].items()
    if "decoder" in k
}
# strict=True: fail loudly if any key is missing or unexpected.
decoder.load_state_dict(decoder_ckpt, strict=True)
# Net receives the UNet and its own deep copy of the decoder. It is wrapped
# in DataParallel before the weights are loaded and unwrapped again right
# after (presumably the checkpoint was saved from a DataParallel-wrapped
# model -- confirm against the training script).
model = torch.nn.DataParallel(Net(unet, copy.deepcopy(decoder)))

# NOTE(security): weights_only=False unpickles arbitrary objects; load trusted
# checkpoints only. map_location="cpu" avoids failures when the checkpoint was
# saved from a device that does not exist on this machine; load_state_dict()
# copies the values into the model's parameters.
net_ckpt_path = "./%s/net_params_%d.pkl" % (args.model_dir, args.epoch)
model.load_state_dict(
    torch.load(net_ckpt_path, map_location="cpu", weights_only=False)
)

# Full inference network: the restored Net followed by the up-sampling blocks
# and output layers of the original (non-copied) `decoder`.
# NOTE(review): the script never calls .eval() on this model; confirm that
# Net contains no layers that behave differently in training mode.
model = torch.nn.Sequential(
    model.module,
    *decoder.up_blocks,
    decoder.conv_norm_out,
    decoder.conv_act,
    decoder.conv_out,
).to(device)
def _match_channel_stats(out, ref, eps=1e-8):
    """Shift and scale each channel of `out` to the mean/std of `ref`.

    Statistics are taken over dimensions 2 and 3 of both tensors. The
    standard deviation of `out` is floored at `eps` so that a constant
    channel (std == 0) yields the reference mean instead of NaN/Inf.
    """
    dims = [2, 3]
    out_std = out.std(dim=dims, keepdim=True).clamp(min=eps)
    normalized = (out - out.mean(dim=dims, keepdim=True)) / out_std
    return normalized * ref.std(dim=dims, keepdim=True) + ref.mean(dim=dims, keepdim=True)


# Sorted lists of PNG files in the input folders. The HR list is collected
# but not used anywhere in the visible part of this script.
test_LR_paths = sorted(glob.glob(os.path.join(args.LR_dir, "*.png")))
test_HR_paths = sorted(glob.glob(os.path.join(args.HR_dir, "*.png")))
os.makedirs(args.SR_dir, exist_ok=True)

# Build the converters once instead of once per image.
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

with torch.no_grad():
    for path in test_LR_paths:
        with Image.open(path) as img:
            LR = img.convert("RGB")
        # [0, 1] image tensor -> batch of one, rescaled to [-1, 1].
        LR = to_tensor(LR).to(device).unsqueeze(0) * 2 - 1
        SR = model(LR)
        # Align the per-channel statistics of the output with the input.
        SR = _match_channel_stats(SR, LR)
        # Back to [0, 1], clamp, and save under the input's file name.
        SR = to_pil((SR[0] / 2 + 0.5).clamp(0, 1).cpu())
        SR.save(os.path.join(args.SR_dir, os.path.basename(path)))