import os
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from deepfillv2 import test_dataset, utils
from config import *


class InpaintingTester:
    """Run a pretrained DeepFillv2 generator over image/mask pairs.

    The generator is built and its weights are loaded once in ``__init__``;
    ``process_image`` and ``process_multiple_images`` reuse it. Network
    hyper-parameters, the device and the checkpoint path are read from the
    ``config`` module (``PAD_TYPE``, ``GPU_DEVICE``, ``DEEPFILL_MODEL_PATH``, ...).
    """

    def __init__(self, save_path, resize_to=None, num_workers=8):
        """Build the generator, load pretrained weights and move it to the device.

        Args:
            save_path: Directory that ``process_multiple_images`` writes results into.
            resize_to: Size handed to ``test_dataset.InpaintDataset``; falls back to
                ``config.RESIZE_TO`` when ``None``.
            num_workers: Number of DataLoader worker processes used per image
                (default 8, matching the previous hard-coded value).
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.save_path = save_path
        self.setsize = resize_to
        self.num_workers = num_workers

        # Options object consumed by utils.create_generator.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # Build once, in eval mode, then load weights and move to the device.
        self.generator = utils.create_generator(opt).eval()
        self.load_model_generator(self.generator)
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the state dict at ``DEEPFILL_MODEL_PATH`` into ``generator``.

        ``weights_only=True`` restricts unpickling to tensors/primitive containers,
        so an untrusted checkpoint cannot execute arbitrary code on load.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH,
            map_location=torch.device(GPU_DEVICE),
            weights_only=True,
        )
        generator.load_state_dict(pretrained_dict)

    def process_image(self, in_image, mask_image, save_image_path):
        """Inpaint one image/mask pair and save the composited result.

        Args:
            in_image: Image source passed to ``test_dataset.InpaintDataset``
                (presumably a file path — confirm against the dataset class).
            mask_image: Mask source passed to ``test_dataset.InpaintDataset``.
            save_image_path: Target path; its directory is created if missing and
                its basename is passed to ``utils.save_sample_png`` as the sample name.
        """
        dataset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.num_workers,
            # Pinning only helps (and only avoids a warning) when CUDA is present.
            pin_memory=torch.cuda.is_available(),
        )

        results_path = os.path.dirname(save_image_path)

        # Presumably the dataset yields a single sample for one image — confirm.
        with torch.no_grad():
            for img, mask in dataloader:
                img = img.to(GPU_DEVICE)
                mask = mask.to(GPU_DEVICE)

                # The generator returns two outputs; only the second is saved.
                _first_out, second_out = self.generator(img, mask)

                # Keep original pixels outside the mask, generated ones inside.
                # Assumes mask is 1 in the hole and 0 elsewhere — TODO confirm.
                second_out_wholeimg = img * (1 - mask) + second_out * mask

                # dirname() is "" for a bare filename; makedirs("") would raise.
                if results_path:
                    os.makedirs(results_path, exist_ok=True)

                utils.save_sample_png(
                    sample_folder=results_path,
                    sample_name=os.path.basename(save_image_path),
                    img_list=[second_out_wholeimg],
                    name_list=["second_out"],
                    pixel_max_cnt=255,
                )

    def process_multiple_images(self, image_mask_pairs):
        """Inpaint each ``(image_path, mask_path)`` pair, continuing past failures.

        Args:
            image_mask_pairs: Iterable of ``(image_path, mask_path)`` tuples.

        Returns:
            A list aligned with ``image_mask_pairs``. Each entry is either the
            expected ``.png`` output path (``<save_path>/<image stem>.png``) or
            ``None`` if processing that pair raised. NOTE(review): this assumes
            ``utils.save_sample_png`` writes to that exact filename — confirm.
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(self.save_path, os.path.basename(img_path))
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                png_images.append(None)
                print(f"Error: {e}")
            else:
                # Swap only the final extension (works for extension-less names too).
                png_images.append(os.path.splitext(save_image_path)[0] + ".png")
        return png_images


# Example usage:
# if __name__ == "__main__":
#     save_path = "./output"
#     resize_to = None  # Default size from config
#     # List of image and mask pairs
#     image_mask_pairs = [
#         ("./input/image.jpg", "./input/mask.jpg"),
#     ]
#     tester = InpaintingTester(save_path, resize_to)
#     # Process multiple images using a loop
#     results = tester.process_multiple_images(image_mask_pairs)
#     print(results)