|
import torch, os, glob, copy |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from PIL import Image |
|
from argparse import ArgumentParser |
|
from torchvision import transforms |
|
from model import Net |
|
|
|
parser = ArgumentParser() |
|
parser.add_argument("--epoch", type=int, default=200) |
|
parser.add_argument("--model_dir", type=str, default="weight") |
|
parser.add_argument("--LR_dir", type=str, default="testset/RealSR/LR") |
|
parser.add_argument("--HR_dir", type=str, default="testset/RealSR/HR") |
|
parser.add_argument("--SR_dir", type=str, default="result/RealSR") |
|
args = parser.parse_args() |
|
|
|
device = torch.device("cuda") |
|
|
|
from diffusers import StableDiffusionPipeline |
|
model_id = "stabilityai/stable-diffusion-2-1-base" |
|
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device) |
|
|
|
vae = pipe.vae |
|
tokenizer = pipe.tokenizer |
|
unet = pipe.unet |
|
noise_scheduler = pipe.scheduler |
|
text_encoder = pipe.text_encoder |
|
|
|
from diffusers.models.autoencoders.vae import Decoder |
|
ckpt_halfdecoder = torch.load("./weight/pretrained/halfDecoder.ckpt", weights_only=False) |
|
decoder = Decoder(in_channels=4, |
|
out_channels=3, |
|
up_block_types=["UpDecoderBlock2D" for _ in range(4)], |
|
block_out_channels=[64, 128, 256, 256], |
|
layers_per_block=2, |
|
norm_num_groups=32, |
|
act_fn="silu", |
|
norm_type="group", |
|
mid_block_add_attention=True).to(device) |
|
decoder_ckpt = {} |
|
for k,v in ckpt_halfdecoder["state_dict"].items(): |
|
if "decoder" in k: |
|
new_k = k.replace("decoder.", "") |
|
decoder_ckpt[new_k] = v |
|
decoder.load_state_dict(decoder_ckpt, strict=True) |
|
|
|
model = torch.nn.DataParallel(Net(unet, copy.deepcopy(decoder))) |
|
model.load_state_dict(torch.load("./%s/net_params_%d.pkl" % (args.model_dir, args.epoch), weights_only=False)) |
|
model = torch.nn.Sequential( |
|
model.module, |
|
*decoder.up_blocks, |
|
decoder.conv_norm_out, |
|
decoder.conv_act, |
|
decoder.conv_out, |
|
).to(device) |
|
|
|
test_LR_paths = list(sorted(glob.glob(os.path.join(args.LR_dir, "*.png")))) |
|
test_HR_paths = list(sorted(glob.glob(os.path.join(args.HR_dir, "*.png")))) |
|
|
|
os.makedirs(args.SR_dir, exist_ok=True) |
|
|
|
with torch.no_grad(): |
|
for i, path in enumerate(test_LR_paths): |
|
LR = Image.open(path).convert("RGB") |
|
LR = transforms.ToTensor()(LR).to(device).unsqueeze(0) * 2 - 1 |
|
SR = model(LR) |
|
SR = (SR - SR.mean(dim=[2,3],keepdim=True)) / SR.std(dim=[2,3],keepdim=True) \ |
|
* LR.std(dim=[2,3],keepdim=True) + LR.mean(dim=[2,3],keepdim=True) |
|
SR = transforms.ToPILImage()((SR[0] / 2 + 0.5).clamp(0, 1).cpu()) |
|
SR.save(os.path.join(args.SR_dir, os.path.basename(path))) |
|
|