kailashahirwar committed
Commit bcf59c3
1 Parent(s): 7758ec5

preprocess garment added

.gitignore ADDED
@@ -0,0 +1,8 @@
+ .env
+ .DS_Store
+ input_image
+ output_image
+ cloth-mask
+ __pycache__
+ *.pyc
+ venv
app.py ADDED
@@ -0,0 +1,82 @@
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ import glob
+ import os
+ from PIL import Image
+
+ import gradio as gr
+ import preprocess
+ from huggingface_hub import login
+
+ def extract_garment(input_img, cls):
+     print(input_img, type(input_img), cls)
+
+     input_dir = "input_image"
+     output_dir = "output_image"
+
+     os.makedirs(input_dir, exist_ok=True)
+     os.makedirs(output_dir, exist_ok=True)
+
+     for f in glob.glob(input_dir + "/*.*"):
+         os.remove(f)
+
+     for f in glob.glob(output_dir + "/*.*"):
+         os.remove(f)
+
+     for f in glob.glob("cloth-mask/*.*"):
+         os.remove(f)
+
+     input_img.save(os.path.join(input_dir, "img.jpg"))
+
+     preprocess.extract_garment(inputs_dir=input_dir, outputs_dir=output_dir, cls=cls)
+
+     return Image.open(glob.glob(output_dir + "/*.*")[0])
+
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 720px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""
+         # Clothes Extraction using U2Net
+         Pull out clothes like tops, bottoms, and dresses from a photo. This implementation is based on the [U2Net](https://github.com/xuebinqin/U-2-Net) model.
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Input Image", type='pil', height="400px", show_label=True)
+                 dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Extract garment",
+                                        info="Select the garment type you wish to extract!")
+
+             output_image = gr.Image(label="Extracted garment", type='pil', height="400px", show_label=True,
+                                     show_download_button=True)
+
+         with gr.Row():
+             submit_button = gr.Button("Submit", variant='primary', scale=1)
+             reset_button = gr.ClearButton(value="Reset", scale=1)
+
+         gr.on(
+             triggers=[submit_button.click],
+             fn=extract_garment,
+             inputs=[input_image, dropdown],
+             outputs=[output_image]
+         )
+
+         reset_button.click(
+             fn=lambda: (None, "upper", None),
+             inputs=[],
+             outputs=[input_image, dropdown, output_image],
+             concurrency_limit=1,
+         )
+
+ if __name__ == '__main__':
+     # login to hugging face
+     login(os.environ.get("HF_TOKEN"))
+
+     demo.launch(show_api=True)
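
For quick local testing, the callback can be exercised without launching the UI. A minimal sketch, assuming the dependencies are installed and a local image file exists ("sample.jpg" is illustrative, not part of the repo):

# sketch: run the app's callback headlessly ("sample.jpg" is illustrative)
from PIL import Image
from app import extract_garment

garment = extract_garment(Image.open("sample.jpg"), "upper")
garment.save("garment.jpg")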
preprocess/__init__.py ADDED
@@ -0,0 +1 @@
+ from .preprocess_garment import segment_garment, extract_garment
preprocess/load_u2net.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ from collections import OrderedDict
+
+ import torch
+
+ from .u2net_cloth_segm import U2NET
+
+
+ def load_cloth_segm_model(device, checkpoint_path, in_ch=3, out_ch=1):
+     if not os.path.exists(checkpoint_path):
+         # fail loudly instead of returning None, which would crash downstream
+         raise FileNotFoundError(f"Invalid checkpoint path: {checkpoint_path}")
+
+     model = U2NET(in_ch=in_ch, out_ch=out_ch)
+
+     model_state_dict = torch.load(checkpoint_path, map_location=device)
+     new_state_dict = OrderedDict()
+     for k, v in model_state_dict.items():
+         name = k[7:]  # strip the `module.` prefix left by DataParallel training
+         new_state_dict[name] = v
+
+     model.load_state_dict(new_state_dict)
+     model = model.to(device=device)
+     model.eval()  # inference only: freeze batch-norm statistics
+
+     print("Checkpoints loaded from path: {}".format(checkpoint_path))
+
+     return model
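
As a usage sketch, the checkpoint published for this Space can be fetched from the Hub and loaded on CPU (the repo id and filename are the ones used in preprocess_garment.py below):

# sketch: load the 4-class cloth-segmentation checkpoint on CPU
import torch
from huggingface_hub import hf_hub_download
from preprocess.load_u2net import load_cloth_segm_model

ckpt = hf_hub_download(repo_id="tryonlabs/u2net-cloth-segmentation",
                       filename="u2net_cloth_segm.pth")
net = load_cloth_segm_model(torch.device("cpu"), ckpt, in_ch=3, out_ch=4)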
preprocess/preprocess_garment.py ADDED
@@ -0,0 +1,114 @@
+ import glob
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm import tqdm
+ import joblib
+ from huggingface_hub import hf_hub_download
+
+ from .load_u2net import load_cloth_segm_model
+ from .utils import NormalizeImage, naive_cutout, resize_by_bigger_index, image_resize
+
+
+ def segment_garment(inputs_dir, outputs_dir, cls="all"):
+     os.makedirs(outputs_dir, exist_ok=True)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     transform_fn = transforms.Compose(
+         [transforms.ToTensor(),
+          NormalizeImage(0.5, 0.5)]
+     )
+
+     # load model from huggingface
+     file_path = hf_hub_download(repo_id="tryonlabs/u2net-cloth-segmentation", filename="u2net_cloth_segm.pth")
+
+     print("model loaded from huggingface:", file_path)
+
+     net = load_cloth_segm_model(device, file_path, in_ch=3, out_ch=4)
+
+     images_list = sorted(os.listdir(inputs_dir))
+     pbar = tqdm(total=len(images_list))
+
+     for image_name in images_list:
+         img = Image.open(os.path.join(inputs_dir, image_name)).convert('RGB')
+         img_size = img.size
+         img = img.resize((768, 768), Image.BICUBIC)
+         image_tensor = transform_fn(img)
+         image_tensor = torch.unsqueeze(image_tensor, 0)
+
+         with torch.no_grad():
+             output_tensor = net(image_tensor.to(device))
+             output_tensor = F.log_softmax(output_tensor[0], dim=1)
+             output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
+             output_tensor = torch.squeeze(output_tensor, dim=0)
+             output_arr = output_tensor.cpu().numpy()
+
+         if cls == "all":
+             classes_to_save = []
+
+             # Check which classes are present in the image. A separate loop
+             # variable is used so the `cls` argument is not shadowed
+             # (shadowing it broke the branch check for subsequent images).
+             for cls_id in range(1, 4):  # Exclude background class (0)
+                 if np.any(output_arr == cls_id):
+                     classes_to_save.append(cls_id)
+         elif cls == "upper":
+             classes_to_save = [1]
+         elif cls == "lower":
+             classes_to_save = [2]
+         elif cls == "dress":
+             classes_to_save = [3]
+         else:
+             raise ValueError(f"Unknown cls: {cls}")
+
+         for cls1 in classes_to_save:
+             alpha_mask = (output_arr == cls1).astype(np.uint8) * 255
+             alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
+             alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
+             alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)
+             alpha_mask_img.save(os.path.join(outputs_dir, f'{image_name.split(".")[0]}_{cls1}.jpg'))
+
+         pbar.update(1)
+
+     pbar.close()
+
+
+ def extract_garment(inputs_dir, outputs_dir, cls="all", resize_to_width=None):
+     os.makedirs(outputs_dir, exist_ok=True)
+     cloth_mask_dir = os.path.join(Path(outputs_dir).parent.absolute(), "cloth-mask")
+     os.makedirs(cloth_mask_dir, exist_ok=True)
+
+     segment_garment(inputs_dir, cloth_mask_dir, cls=cls)
+
+     images_path = sorted(glob.glob(os.path.join(inputs_dir, "*")))
+     masks_path = sorted(glob.glob(os.path.join(cloth_mask_dir, "*")))
+     # NOTE: only the first input image is composited; app.py clears the
+     # input directory before every request, so it holds a single image.
+     img = Image.open(images_path[0])
+
+     for mask_path in masks_path:
+         mask = Image.open(mask_path)
+
+         cutout = np.array(naive_cutout(img, mask))
+         cutout = resize_by_bigger_index(cutout)
+
+         # alpha-blend the RGBA cutout onto a white 1024x768 canvas, centred
+         canvas = np.ones((1024, 768, 3), np.uint8) * 255
+         y1, y2 = (canvas.shape[0] - cutout.shape[0]) // 2, (canvas.shape[0] + cutout.shape[0]) // 2
+         x1, x2 = (canvas.shape[1] - cutout.shape[1]) // 2, (canvas.shape[1] + cutout.shape[1]) // 2
+
+         alpha_s = cutout[:, :, 3] / 255.0
+         alpha_l = 1.0 - alpha_s
+
+         for c in range(0, 3):
+             canvas[y1:y2, x1:x2, c] = (alpha_s * cutout[:, :, c] +
+                                        alpha_l * canvas[y1:y2, x1:x2, c])
+
+         # resize image before saving
+         if resize_to_width:
+             canvas = image_resize(canvas, width=resize_to_width)
+
+         canvas = Image.fromarray(canvas)
+
+         canvas.save(os.path.join(outputs_dir, f"{os.path.basename(mask_path).split('.')[0]}.jpg"), format='JPEG')
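
Called directly, the pipeline runs over a directory of images; a sketch with illustrative directory names (one input image at a time, as app.py arranges):

# sketch: extract all garment classes from photos/ into garments/
# ("photos" and "garments" are illustrative directory names)
from preprocess import extract_garment

extract_garment(inputs_dir="photos", outputs_dir="garments", cls="all")
# the intermediate masks land in a sibling "cloth-mask" directory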
preprocess/u2net_cloth_segm.py ADDED
@@ -0,0 +1,550 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+ def _upsample_like(src, tar):
+     # F.interpolate replaces the deprecated F.upsample
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6, hx7
+         del hx6d, hx5d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5
+         del hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         del hx2dup, hx3dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         """
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d, hx1d
+         del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
+         """
+
+         return d0, d1, d2, d3, d4, d5, d6
+
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return d0, d1, d2, d3, d4, d5, d6
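
A quick shape check for the 4-class configuration this Space uses (a sketch; random input, CPU only):

# sketch: verify the fused output d0 matches the input resolution
import torch
from preprocess.u2net_cloth_segm import U2NET

net = U2NET(in_ch=3, out_ch=4).eval()
with torch.no_grad():
    d0, *sides = net(torch.randn(1, 3, 768, 768))
print(d0.shape)  # torch.Size([1, 4, 768, 768])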
preprocess/utils.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from pathlib import Path
+
+ import cv2
+ from PIL import Image
+ from torchvision import transforms
+
+
+ class NormalizeImage(object):
+     """Normalize a given tensor with a given mean and standard deviation.
+
+     Args:
+         mean (float): mean to subtract from tensors
+         std (float): standard deviation to divide tensors by
+     """
+
+     def __init__(self, mean, std):
+         assert isinstance(mean, float) and isinstance(std, float)
+         self.mean = mean
+         self.std = std
+
+         self.normalize_1 = transforms.Normalize(self.mean, self.std)
+         self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
+         self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
+
+     def __call__(self, image_tensor):
+         if image_tensor.shape[0] == 1:
+             return self.normalize_1(image_tensor)
+
+         elif image_tensor.shape[0] == 3:
+             return self.normalize_3(image_tensor)
+
+         elif image_tensor.shape[0] == 18:
+             return self.normalize_18(image_tensor)
+
+         else:
+             # the original `assert "<message>"` was a no-op; raise instead
+             raise ValueError("Please set proper channels! Normalization implemented only for 1, 3 and 18")
+
+
+ def naive_cutout(img, mask):
+     # composite the image over a transparent canvas, using the mask as alpha
+     empty = Image.new("RGBA", img.size, 0)
+     cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
+     return cutout
+
+
+ def resize_by_bigger_index(crop):
+     # resize while keeping the aspect ratio: fit to width for wide crops
+     # (h/w <= 1.33), otherwise fit to height
+     crop_shape = crop.shape  # h x w x c
+     if crop_shape[0] / crop_shape[1] <= 1.33:
+         resized_crop = image_resize(crop, width=768)
+     else:
+         resized_crop = image_resize(crop, height=1024)
+     return resized_crop
+
+
+ def image_resize(image, width=None, height=None):
+     dim = None
+     (h, w) = image.shape[:2]
+
+     if width is None and height is None:
+         return image
+
+     if width is None:
+         r = height / float(h)
+         dim = (int(w * r), height)
+
+     else:
+         r = width / float(w)
+         dim = (width, int(h * r))
+
+     resized = cv2.resize(image, dim)
+
+     return resized
+
+
+ def convert_to_jpg(image_path, output_dir, size=None):
+     """
+     Convert an image to JPG format.
+     :param image_path: image path
+     :param output_dir: output directory
+     :param size: desired size of the image (w, h)
+     """
+     img = cv2.imread(image_path)
+     if size is not None:
+         # image_resize keeps the aspect ratio, so only size[0] (width) is honoured
+         img = image_resize(img, width=size[0], height=size[1])
+
+     filename = Path(image_path).name
+     cv2.imwrite(os.path.join(output_dir, filename.split(".")[0] + ".jpg"), img)
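
The resize helpers keep the aspect ratio rather than matching both dimensions; a sketch on a dummy array:

# sketch: aspect-preserving resizes on a 600x400 dummy image
import numpy as np
from preprocess.utils import image_resize, resize_by_bigger_index

img = np.zeros((600, 400, 3), dtype=np.uint8)   # h x w x c
print(image_resize(img, width=768).shape)       # (1152, 768, 3)
print(resize_by_bigger_index(img).shape)        # (1024, 682, 3): h/w = 1.5 > 1.33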
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pillow
+ gradio==4.44.1
+ torch
+ torchvision
+ numpy==1.26.1
+ tqdm
+ opencv-python
+ joblib
+ huggingface-hub==0.25.0
+ python-dotenv
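
app.py reads HF_TOKEN via python-dotenv before logging in to the Hub; locally that token can come from a .env file, which the .gitignore above already excludes from the repo. An illustrative .env (the value is a placeholder, not a real token):

# .env (illustrative; supply your own token)
HF_TOKEN=hf_xxxxxxxxxxxxxxxx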