|
import os
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
from deepfillv2 import network
|
|
import skimage
|
|
|
|
from config import GPU_DEVICE
|
|
|
|
|
|
|
|
|
|
|
|
def create_generator(opt):
    """Build a ``network.GatedGenerator`` and run weight initialisation on it.

    Args:
        opt: option object forwarded to ``network.GatedGenerator``; its
            ``init_type`` and ``init_gain`` attributes are passed on to
            ``network.weights_init``.

    Returns:
        The generator instance after ``network.weights_init`` was applied.
    """
    gen = network.GatedGenerator(opt)
    print("-- Generator is created! --")

    network.weights_init(gen, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialized generator with %s type --" % opt.init_type)

    return gen
|
|
|
|
|
|
def create_discriminator(opt):
    """Build a ``network.PatchDiscriminator`` and run weight initialisation.

    Args:
        opt: option object forwarded to ``network.PatchDiscriminator``; its
            ``init_type`` and ``init_gain`` attributes are passed on to
            ``network.weights_init``.

    Returns:
        The discriminator instance after ``network.weights_init`` was applied.
    """
    disc = network.PatchDiscriminator(opt)
    print("-- Discriminator is created! --")

    network.weights_init(disc, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialize discriminator with %s type --" % opt.init_type)

    return disc
|
|
|
|
|
|
def create_perceptualnet():
    """Construct and return a ``network.PerceptualNet`` instance.

    Unlike the generator/discriminator factories, this one does not call
    ``network.weights_init``; the network is returned as constructed.
    """
    net = network.PerceptualNet()
    print("-- Perceptual network is created! --")
    return net
|
|
|
|
|
|
|
|
|
|
|
|
def text_readlines(filename):
    """Read a text file and return its lines without the trailing newline.

    Args:
        filename: path of the file to read.

    Returns:
        A list with one string per line, each stripped of its terminating
        ``"\\n"`` (if present). When the file cannot be opened, an empty
        list is returned instead of raising (deliberate best-effort
        behaviour that callers may rely on).
    """
    try:
        file = open(filename, "r")
    except IOError:
        return []

    # ``with`` guarantees the handle is closed even if reading fails.
    with file:
        # Remove only a real trailing newline: the last line of a file need
        # not end with one, and blindly chopping the final character would
        # truncate it.
        return [line[:-1] if line.endswith("\n") else line for line in file]
|
|
|
|
|
|
def savetxt(name, loss_log):
    """Dump *loss_log* to the text file *name* using ``numpy.savetxt``.

    The log is converted with ``np.array`` first, so plain (nested) Python
    lists of numbers are accepted as well as arrays.
    """
    np.savetxt(name, np.array(loss_log))
|
|
|
|
|
|
def get_files(path, mask=False):
    """Recursively collect the full paths of all files under *path*.

    macOS ``.DS_Store`` metadata files are skipped.

    Args:
        path: root directory, walked recursively with ``os.walk``.
        mask: unused; kept only so existing callers passing it still work.

    Returns:
        A list of ``os.path.join(root, filename)`` paths in ``os.walk``
        order. A non-existent *path* yields an empty list, because
        ``os.walk`` ignores errors by default.
    """
    ret = []
    for root, _dirs, files in os.walk(path):
        for filename in files:
            # Skip Finder metadata; every other file is kept. (The previous
            # ``!=`` test was inverted and kept ONLY .DS_Store files.)
            if filename == ".DS_Store":
                continue
            ret.append(os.path.join(root, filename))
    return ret
|
|
|
|
|
|
def get_names(path):
    """Return the bare names of every file found anywhere under *path*.

    The tree is walked recursively with ``os.walk``. Only the final name
    component is kept (no directory part), and no file is filtered out.
    """
    return [name for _root, _dirs, names in os.walk(path) for name in names]
|
|
|
|
|
|
def text_save(content, filename, mode="a"):
    """Write each item of *content* to *filename*, one item per line.

    Args:
        content: iterable of items; each is converted with ``str`` and
            followed by ``"\\n"``. Any iterable is accepted, not only
            sequences.
        filename: destination path.
        mode: ``open`` mode. The default ``"a"`` appends, so repeated calls
            accumulate lines; pass ``"w"`` to overwrite.
    """
    # ``with`` closes the file even if converting or writing an item raises.
    with open(filename, mode) as fh:
        fh.writelines(str(item) + "\n" for item in content)
|
|
|
|
|
|
def check_path(path):
    """Create directory *path* (including missing parents) if it is absent.

    Nothing is done when *path* already exists, whether as a directory or as
    some other filesystem entry.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
|
|
|
|
|
|
|
|
|
|
|
|
def save_sample_png(
    sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255
):
    """Write image tensors from *img_list* to ``sample_folder/sample_name``.

    For each tensor the pipeline is:
    1. multiply by 255;
    2. permute ``(0, 2, 3, 1)`` and keep batch element 0;
    3. move to the CPU and clip to ``[0, pixel_max_cnt]``;
    4. cast to ``uint8`` and convert RGB -> BGR for OpenCV;
    5. write with ``cv2.imwrite``.

    Assumes each tensor is (N, C, H, W) with 3 channels and values roughly
    in [0, 1] — TODO confirm against callers.

    NOTE(review): every iteration writes to the SAME path, so only the last
    element of *img_list* survives on disk, and *name_list* is never read.
    Confirm whether callers depend on this before changing the naming.

    Args:
        sample_folder: output directory.
        sample_name: output file name (joined onto *sample_folder*).
        img_list: sequence of image tensors.
        name_list: currently unused.
        pixel_max_cnt: upper bound used when clipping pixel values.
    """
    for img in img_list:
        scaled = img * 255
        # Work on a detached CPU copy so the caller's tensor is untouched.
        hwc = scaled.clone().data.permute(0, 2, 3, 1)[0, :, :, :].to("cpu").numpy()
        pixels = np.clip(hwc, 0, pixel_max_cnt).astype(np.uint8)
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(sample_folder, sample_name), bgr)
|
|
|
|
|
|
def psnr(pred, target, pixel_max_cnt=255):
    """Peak signal-to-noise ratio in decibels between two tensors.

    Computed as ``20 * log10(pixel_max_cnt / RMSE)``, where the RMSE is taken
    over all elements of ``target - pred``.

    Args:
        pred: predicted tensor.
        target: reference tensor, broadcast-compatible with *pred*.
        pixel_max_cnt: maximum possible pixel value.

    Returns:
        The PSNR in dB. Returns ``float("inf")`` when the inputs are
        identical (zero error) instead of raising ``ZeroDivisionError``.
    """
    diff = target - pred
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        # Perfect reconstruction: PSNR is unbounded.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt / rmse_avg)
|
|
|
|
|
|
def grey_psnr(pred, target, pixel_max_cnt=255):
    """PSNR in decibels after summing both tensors over dimension 0.

    Both inputs are summed along ``dim=0``. The peak value is taken as
    ``pixel_max_cnt * 3``; this presumably assumes dim 0 holds 3 colour
    channels — TODO confirm against callers.

    Args:
        pred: predicted tensor.
        target: reference tensor with the same shape as *pred*.
        pixel_max_cnt: maximum possible value of a single element.

    Returns:
        The PSNR in dB. Returns ``float("inf")`` when the summed tensors are
        identical (zero error) instead of raising ``ZeroDivisionError``.
    """
    pred = torch.sum(pred, dim=0)
    target = torch.sum(target, dim=0)
    diff = target - pred
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        # Perfect reconstruction: PSNR is unbounded.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
|
|
|
|
|
|
def ssim(pred, target):
    """Structural similarity between the first images of two batches.

    Both tensors are detached, permuted with ``(0, 2, 3, 1)``, reduced to
    batch element 0, moved to the CPU and converted to NumPy. They are then
    scored with scikit-image's SSIM, using the last axis as the channel axis.

    Args:
        pred: predicted batch; assumed (N, C, H, W) — TODO confirm.
        target: reference batch with the same shape as *pred*.

    Returns:
        The SSIM value reported by scikit-image.
    """
    # ``Tensor.numpy()`` only works on CPU tensors, so convert via ``.cpu()``
    # rather than moving the data onto the GPU device.
    pred_img = pred.detach().permute(0, 2, 3, 1)[0].cpu().numpy()
    target_img = target.detach().permute(0, 2, 3, 1)[0].cpu().numpy()

    # NOTE(review): recent scikit-image releases may require an explicit
    # ``data_range`` for floating-point inputs — confirm against the
    # installed version.
    try:
        from skimage.metrics import structural_similarity
    except ImportError:
        # Old scikit-image only ships the since-removed ``compare_ssim``.
        from skimage.measure import compare_ssim

        return compare_ssim(target_img, pred_img, multichannel=True)

    try:
        # ``channel_axis`` supersedes the deprecated ``multichannel`` flag.
        return structural_similarity(target_img, pred_img, channel_axis=-1)
    except TypeError:
        # Releases predating ``channel_axis`` only accept ``multichannel``.
        return structural_similarity(target_img, pred_img, multichannel=True)
|
|
|