File size: 4,362 Bytes
205273a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from types import SimpleNamespace
from deepfillv2 import test_dataset, utils
from config import *

class InpaintingTester:
    """Run a pretrained inpainting generator over image/mask pairs.

    The generator is built and its weights are loaded once in ``__init__``;
    ``process_image`` and ``process_multiple_images`` then reuse it. Network
    hyper-parameters, the weights path and the target device are all read
    from the ``config`` module (``PAD_TYPE``, ``DEEPFILL_MODEL_PATH``,
    ``GPU_DEVICE``, ...).
    """

    def __init__(self, save_path, resize_to=None):
        """Build the generator, load its weights and move it to ``GPU_DEVICE``.

        Args:
            save_path: Directory under which ``process_multiple_images``
                writes its results.
            resize_to: Size forwarded to ``test_dataset.InpaintDataset``;
                falls back to ``config.RESIZE_TO`` when ``None``.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.save_path = save_path
        self.setsize = resize_to

        # Options object consumed by utils.create_generator.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # Built once and kept in eval mode; reused for every image.
        self.generator = utils.create_generator(opt).eval()
        self.load_model_generator(self.generator)
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the state dict at ``DEEPFILL_MODEL_PATH`` into ``generator``.

        ``weights_only=True`` restricts unpickling to tensors and primitive
        containers, so a tampered checkpoint cannot execute arbitrary code.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH, map_location=torch.device(GPU_DEVICE), weights_only=True
        )
        generator.load_state_dict(pretrained_dict)

    def process_image(self, in_image, mask_image, save_image_path):
        """Inpaint one image/mask pair and save the composited result.

        Args:
            in_image: Image source passed to ``test_dataset.InpaintDataset``.
            mask_image: Mask source passed to ``test_dataset.InpaintDataset``.
            save_image_path: Target path; its directory becomes the sample
                folder and its basename the sample name handed to
                ``utils.save_sample_png``.
        """
        dataset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
        )

        results_path = os.path.dirname(save_image_path)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        # exist_ok avoids the check-then-create race of os.path.exists.
        if results_path:
            os.makedirs(results_path, exist_ok=True)
        sample_name = os.path.basename(save_image_path)

        for img, mask in dataloader:
            img = img.to(GPU_DEVICE)
            mask = mask.to(GPU_DEVICE)

            with torch.no_grad():
                # Only the generator's second output is saved; the first is unused.
                _, second_out = self.generator(img, mask)
                # Blend: input weighted by (1 - mask), generator output by mask.
                # NOTE(review): assumes mask is 1 inside the hole - confirm.
                second_out_wholeimg = img * (1 - mask) + second_out * mask

            utils.save_sample_png(
                sample_folder=results_path,
                sample_name=sample_name,
                img_list=[second_out_wholeimg],
                name_list=["second_out"],
                pixel_max_cnt=255,
            )

    def process_multiple_images(self, image_mask_pairs):
        """Inpaint each ``(image_path, mask_path)`` pair into ``self.save_path``.

        Processing is best-effort: a failing pair is reported on stdout and
        skipped, and the remaining pairs are still processed.

        Returns:
            A list aligned index-for-index with ``image_mask_pairs``. Each entry
            is the expected ``.png`` output path (``self.save_path`` joined with
            the input basename, extension replaced by ``.png``), or ``None`` if
            that pair raised.
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(self.save_path, os.path.basename(img_path))
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
            except Exception as e:
                print(f"Error: {e}")
                # Placeholder keeps the result list aligned with the inputs.
                png_images.append(None)
            else:
                # splitext swaps only the trailing extension; str.replace would
                # also hit earlier matches, and with "" it corrupts the path.
                png_images.append(os.path.splitext(save_image_path)[0] + ".png")
        return png_images

# Main execution
# if __name__ == "__main__":
#     save_path = "./output"
#     resize_to = None  # Default size from config

#     # List of image and mask pairs
#     image_mask_pairs = [
#         ( "./input/image.jpg", "./input/mask.jpg"),
#     ]

#     tester = InpaintingTester(save_path, resize_to)
    
#     # Process multiple images using a loop
#     results=tester.process_multiple_images(image_mask_pairs)
#     print(results)