|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
from types import SimpleNamespace
|
|
from deepfillv2 import test_dataset, utils
|
|
from config import *
|
|
|
|
class InpaintingTester:
    """Run a pretrained DeepFill v2 generator over image/mask pairs and save results.

    Generator hyper-parameters, the checkpoint path (``DEEPFILL_MODEL_PATH``) and
    the target device (``GPU_DEVICE``) come from the module-level ``config``
    constants.
    """

    def __init__(self, save_path, resize_to=None):
        """Build the generator, load its weights and move it to ``GPU_DEVICE``.

        Args:
            save_path: Directory that ``process_multiple_images`` writes into.
            resize_to: Size handed to ``test_dataset.InpaintDataset``; defaults
                to ``RESIZE_TO`` from ``config``.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.save_path = save_path
        self.setsize = resize_to

        # utils.create_generator reads its settings as attributes of ``opt``,
        # so a SimpleNamespace stands in for a parsed-options object.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # The generator is only used for inference here, hence eval().
        self.generator = utils.create_generator(opt).eval()
        self.load_model_generator(self.generator)
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the ``DEEPFILL_MODEL_PATH`` state dict into ``generator`` in place.

        ``weights_only=True`` restricts unpickling to tensors and plain
        containers, so a tampered checkpoint cannot run arbitrary code. Tensors
        are mapped onto ``GPU_DEVICE`` while loading. ``load_state_dict`` is
        strict by default, so missing or unexpected keys raise.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH, map_location=torch.device(GPU_DEVICE), weights_only=True
        )
        generator.load_state_dict(pretrained_dict)

    def process_image(self, in_image, mask_image, save_image_path, num_workers=0):
        """Inpaint one image/mask pair and save the composited result.

        Args:
            in_image: Image argument forwarded to ``test_dataset.InpaintDataset``.
            mask_image: Mask argument forwarded to ``test_dataset.InpaintDataset``.
            save_image_path: Its directory is the output folder and its basename
                the ``sample_name`` given to ``utils.save_sample_png``.
            num_workers: DataLoader worker processes. Defaults to 0 (load in the
                calling process). The loader is rebuilt on every call, and for a
                single pair, starting worker processes costs far more than it saves.
        """
        dataset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        device = torch.device(GPU_DEVICE)
        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=num_workers,
            # Pinned host memory only helps host-to-CUDA copies.
            pin_memory=device.type == "cuda",
        )

        results_path = os.path.dirname(save_image_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if results_path:
            os.makedirs(results_path, exist_ok=True)
        sample_name = os.path.basename(save_image_path)

        # NOTE(review): every sample is saved under the same name, so if the
        # dataset ever yields more than one item the later ones overwrite the
        # earlier ones. Confirm InpaintDataset holds exactly one pair.
        for img, mask in loader:
            img = img.to(device)
            mask = mask.to(device)
            with torch.no_grad():
                # The generator returns two outputs; only the second one is kept.
                _, second_out = self.generator(img, mask)
            # Keep input pixels where mask is 0 and generated pixels where mask
            # is 1 (assumes mask is 1 inside the hole; TODO confirm with dataset).
            second_out_wholeimg = img * (1 - mask) + second_out * mask
            utils.save_sample_png(
                sample_folder=results_path,
                sample_name=sample_name,
                img_list=[second_out_wholeimg],
                name_list=["second_out"],
                pixel_max_cnt=255,
            )

    def process_multiple_images(self, image_mask_pairs):
        """Inpaint every ``(image_path, mask_path)`` pair into ``self.save_path``.

        A failing pair is reported on stdout and does not stop the batch.

        Returns:
            A list parallel to ``image_mask_pairs``. Each pair that succeeded
            contributes ``self.save_path`` joined with the input basename, with
            the extension swapped for ``.png``; each pair that raised contributes
            ``None``. (Assumes ``utils.save_sample_png`` writes to that path;
            confirm.)
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(self.save_path, os.path.basename(img_path))
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
            except Exception as e:
                # Best-effort batch: record the failure and move on.
                print(f"Error: {e}")
                png_images.append(None)
            else:
                png_images.append(self._png_path(save_image_path))
        return png_images

    @staticmethod
    def _png_path(path):
        """Return ``path`` with its final extension (if any) replaced by ``.png``."""
        # splitext touches only the last suffix. str.replace would also rewrite
        # matching text in directory names and, for an empty extension, insert
        # ".png" between every character.
        return os.path.splitext(path)[0] + ".png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|