# Source page header: NeuralFalcon - "Upload 4 files" (commit 205273a, verified)
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from types import SimpleNamespace
from deepfillv2 import test_dataset, utils
from config import *
class InpaintingTester:
    """Run a pretrained DeepFill generator over image/mask pairs.

    The generator network is created and its weights are loaded once, in
    ``__init__``; ``process_image`` and ``process_multiple_images`` reuse it
    for every call.
    """

    def __init__(self, save_path, resize_to=None):
        """Build the generator, load its weights and move it to ``GPU_DEVICE``.

        Args:
            save_path: Directory under which ``process_multiple_images``
                places its results.
            resize_to: Size handed to ``InpaintDataset``. ``None`` falls back
                to ``RESIZE_TO`` from ``config``.
        """
        self.save_path = save_path
        self.setsize = RESIZE_TO if resize_to is None else resize_to

        # Options consumed by utils.create_generator; all values come from config.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # Initialize the generator (only once) and put it in eval mode.
        self.generator = utils.create_generator(opt).eval()

        # Load pretrained model weights.
        # print("-- INPAINT: Loading Pretrained Model --")
        self.load_model_generator(self.generator)

        # Move the generator to the configured device.
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the state dict stored at ``DEEPFILL_MODEL_PATH`` into ``generator``.

        The checkpoint is mapped onto ``GPU_DEVICE`` and read with
        ``weights_only=True``.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH,
            map_location=torch.device(GPU_DEVICE),
            weights_only=True,
        )
        generator.load_state_dict(pretrained_dict)

    def process_image(self, in_image, mask_image, save_image_path):
        """Inpaint one image and hand the result to ``utils.save_sample_png``.

        Args:
            in_image: Input image, passed straight to ``InpaintDataset``.
            mask_image: Mask, passed straight to ``InpaintDataset``.
            save_image_path: Destination path. Its directory is used as the
                sample folder and its base name as the sample name.
        """
        dataset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            # Load in the main process: the dataset is built from a single
            # image/mask pair, so spawning worker processes on every call
            # would only add start-up cost.
            num_workers=0,
            pin_memory=True,
        )

        # A bare file name has an empty dirname; fall back to the current
        # directory instead of calling os.makedirs("").
        results_path = os.path.dirname(save_image_path) or "."
        sample_name = os.path.basename(save_image_path)

        for img, mask in dataloader:
            img = img.to(GPU_DEVICE)
            mask = mask.to(GPU_DEVICE)

            # The generator returns two outputs; only the second one is saved.
            with torch.no_grad():
                _, second_out = self.generator(img, mask)

            # Blend: input pixels weighted by (1 - mask), generated pixels
            # weighted by mask.
            second_out_wholeimg = img * (1 - mask) + second_out * mask

            os.makedirs(results_path, exist_ok=True)
            utils.save_sample_png(
                sample_folder=results_path,
                sample_name=sample_name,
                img_list=[second_out_wholeimg],
                name_list=["second_out"],
                pixel_max_cnt=255,
            )

    @staticmethod
    def _png_result_path(save_image_path):
        """Return ``save_image_path`` with its extension swapped for ``.png``.

        Only the trailing extension is touched; a path without an extension
        simply gets ``.png`` appended.
        """
        return os.path.splitext(save_image_path)[0] + ".png"

    def process_multiple_images(self, image_mask_pairs):
        """Inpaint every ``(image_path, mask_path)`` pair, best effort.

        Args:
            image_mask_pairs: Iterable of ``(image_path, mask_path)`` tuples.

        Returns:
            A list with one entry per pair, in input order: the ``.png`` path
            derived from the save path on success, or ``None`` when processing
            that pair raised. NOTE(review): the file on disk is named by
            ``utils.save_sample_png``; confirm it matches the path returned
            here.
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(
                    self.save_path, os.path.basename(img_path)
                )
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
            except Exception as e:
                # Keep going with the remaining pairs; None marks the failure
                # so results stay aligned with the input order.
                png_images.append(None)
                print(f"Error: {e}")
            else:
                png_images.append(self._png_result_path(save_image_path))
        # print("-- All Inpainting is finished --")
        return png_images
# Main execution
# if __name__ == "__main__":
# save_path = "./output"
# resize_to = None # Default size from config
# # List of image and mask pairs
# image_mask_pairs = [
# ( "./input/image.jpg", "./input/mask.jpg"),
# ]
# tester = InpaintingTester(save_path, resize_to)
# # Process multiple images using a loop
# results=tester.process_multiple_images(image_mask_pairs)
# print(results)