add seg
- app.py +4 -3
- demo/demos.py +25 -0
- demo/model.py +102 -1
- requirements.txt +2 -1
- seger.py +283 -0
app.py
CHANGED
@@ -8,7 +8,7 @@ os.system('mim install mmcv-full==1.7.0')
 
 from demo.model import Model_all
 import gradio as gr
-from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw
+from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg
 import torch
 import subprocess
 import shlex
@@ -22,6 +22,7 @@ urls = {
 urls_mmpose = [
     'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
     'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
+    'https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'
 ]
 if os.path.exists('models') == False:
     os.mkdir('models')
@@ -69,7 +70,7 @@ with gr.Blocks(css='style.css') as demo:
             create_demo_sketch(model.process_sketch)
         with gr.TabItem('Draw'):
             create_demo_draw(model.process_draw)
+        with gr.TabItem('Segmentation'):
+            create_demo_seg(model.process_seg)
 
-# demo.queue(api_open=False).launch(server_name='0.0.0.0')
-# demo.queue(show_api=False, enable_queue=False).launch(server_name='0.0.0.0')
 demo.queue().launch(debug=True, server_name='0.0.0.0')
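The app changes do two things: register the new Segmentation tab, and append the DeepLabV2 COCO-Stuff checkpoint to urls_mmpose so it lands in models/ at startup alongside the detection and pose weights. The download loop itself sits outside these hunks; a minimal sketch of the pattern, assuming the wget-via-subprocess style suggested by the subprocess/shlex imports (hypothetical, not the verbatim loop):

# Hypothetical sketch of the startup download; the real loop in app.py
# is outside this hunk. urls_mmpose is the list defined above.
import os, subprocess, shlex

for url in urls_mmpose:
    target = os.path.join('models', url.split('/')[-1])
    if not os.path.exists(target):
        subprocess.run(shlex.split(f'wget {url} -O {target}'))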
demo/demos.py
CHANGED
@@ -70,6 +70,31 @@ def create_demo_sketch(process):
         run_button.click(fn=process, inputs=ips, outputs=[result])
     return demo
 
+def create_demo_seg(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Segmentation)')
+        with gr.Row():
+            with gr.Column():
+                input_img = gr.Image(source='upload', type="numpy")
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value='crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
+                    run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+        ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
 def create_demo_draw(process):
     with gr.Blocks() as demo:
         with gr.Row():
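create_demo_seg follows the same factory pattern as the other tabs: it builds a self-contained gr.Blocks, and the order of the nine components in ips must match the signature of the process callback (here Model_all.process_seg). A minimal smoke test with a stub callback, assuming the gradio 3.x API pinned by this Space (stub_process is hypothetical):

# Run the new tab standalone with a stub in place of Model_all.process_seg.
import gradio as gr
from demo.demos import create_demo_seg

def stub_process(input_img, type_in, prompt, neg_prompt, pos_prompt,
                 fix_sample, scale, con_strength, base_model):
    # Echo the input twice to fill the two-column gallery.
    return [input_img, input_img]

create_demo_seg(stub_process).launch()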
demo/model.py
CHANGED
@@ -13,7 +13,30 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process
 import os
 import cv2
 import numpy as np
-
+from seger import seger, Colorize
+import torch.nn.functional as F
+
+def preprocessing(image, device):
+    # Resize
+    scale = 640 / max(image.shape[:2])
+    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
+    raw_image = image.astype(np.uint8)
+
+    # Subtract mean values
+    image = image.astype(np.float32)
+    image -= np.array(
+        [
+            float(104.008),
+            float(116.669),
+            float(122.675),
+        ]
+    )
+
+    # Convert to torch.Tensor and add "batch" axis
+    image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
+    image = image.to(device)
+
+    return image, raw_image
 
 def imshow_keypoints(img,
                      pose_result,
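The copied preprocessing helper resizes the long side to 640 and subtracts per-channel means; the constants (104.008, 116.669, 122.675) appear to be the Caffe-style BGR means the deeplab-pytorch COCO-Stuff weights were trained with, so input is expected in OpenCV's BGR order, not RGB. A quick shape check (sketch):

# BGR uint8 image in, NCHW float tensor out; 640/max(H, W) scaling
# leaves a 480x640 input unchanged.
import numpy as np
bgr = np.zeros((480, 640, 3), dtype=np.uint8)
tensor, raw = preprocessing(bgr, 'cpu')
print(tensor.shape, raw.dtype)  # torch.Size([1, 3, 480, 640]) uint8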
@@ -118,6 +141,13 @@ class Model_all:
         self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
         self.model_edge.to(device)
 
+        # segmentation part
+        self.model_seger = seger().to(device)
+        self.model_seger.eval()
+        self.coler = Colorize(n=182)
+        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
+
         # keypose part
         self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                   use_conv=False).to(device)
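The new model_seg adapter is configured identically to the existing sketch and keypose adapters. cin=int(3*64) matches a 3-channel 512x512 condition map after the Adapter's internal 8x pixel-unshuffle (3 * 8^2 = 192 input channels), and channels=[320, 640, 1280, 1280] lines up with the four UNet feature scales; the pixel-unshuffle detail is inferred from the Adapter definition elsewhere in the repo, not from this diff. Note the attribute name self.coler: a misspelling of "colorer", but it matches its later use in process_seg.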
@@ -218,6 +248,77 @@ class Model_all:
 
         return [im_edge, x_samples_ddim]
 
+    @torch.no_grad()
+    def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
+                    con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        im = cv2.resize(input_img, (512, 512))
+
+        if type_in == 'Segmentation':
+            im_seg = im.copy()
+            im = img2tensor(im).unsqueeze(0) / 255.
+            labelmap = im.float()
+        elif type_in == 'Image':
+            im, _ = preprocessing(im, self.device)
+            _, _, H, W = im.shape
+
+            # Image -> Probability map
+            logits = self.model_seger(im)
+            logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
+            probs = F.softmax(logits, dim=1)[0]
+            probs = probs.cpu().data.numpy()
+            labelmap = np.argmax(probs, axis=0)
+
+            labelmap = self.coler(labelmap)
+            labelmap = np.transpose(labelmap, (1, 2, 0))
+            labelmap = cv2.resize(labelmap, (512, 512))
+            labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
+            im_seg = tensor2img(labelmap)
+            labelmap = labelmap.unsqueeze(0)
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter = self.model_seg(labelmap.to(self.device))
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_seg, x_samples_ddim]
+
     @torch.no_grad()
     def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
         if self.current_base != base_model:
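process_seg mirrors process_sketch closely: it even passes mode='sketch' to the sampler, so the segmentation condition enters only through features_adapter1. The 0-1 con_strength slider is remapped to a step count over the 50 DDIM steps before sampling; exactly how sampler.sample consumes that integer is defined outside this diff. The mapping itself:

# Slider value -> step count passed to the sampler.
for s in (0.0, 0.4, 1.0):
    print(s, int((1 - s) * 50))  # 0.0 -> 50, 0.4 -> 30, 1.0 -> 0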
requirements.txt
CHANGED
@@ -15,4 +15,5 @@ kornia==0.6.8
 openmim
 mmpose
 mmdet
-psutil
+psutil
+blobfile
seger.py
ADDED
@@ -0,0 +1,283 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from basicsr.utils import img2tensor, tensor2img

_BATCH_NORM = nn.BatchNorm2d
_BOTTLENECK_EXPANSION = 4

import blobfile as bf

def _list_image_files_recursively(data_dir):
    results = []
    for entry in sorted(bf.listdir(data_dir)):
        full_path = bf.join(data_dir, entry)
        ext = entry.split(".")[-1]
        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
            results.append(full_path)
        elif bf.isdir(full_path):
            results.extend(_list_image_files_recursively(full_path))
    return results
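import blobfile as bf and _list_image_files_recursively are dataset-listing utilities carried over with this file; nothing in the demo path calls them, but the import is what makes the new blobfile entry in requirements.txt necessary. img2tensor/tensor2img are likewise imported here without being used in this module.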
def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


def labelcolormap(N):
    if N == 35:  # cityscape
        cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                         (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                         (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                         (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                         (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
                        dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            id = i + 1  # let's give 0 a color
            for j in range(7):
                str_id = uint82bin(id)
                r = r ^ (np.uint8(str_id[-1]) << (7 - j))
                g = g ^ (np.uint8(str_id[-2]) << (7 - j))
                b = b ^ (np.uint8(str_id[-3]) << (7 - j))
                id = id >> 3
            cmap[i, 0] = r
            cmap[i, 1] = g
            cmap[i, 2] = b

    return cmap


class Colorize(object):
    def __init__(self, n=182):
        self.cmap = labelcolormap(n)

    def __call__(self, gray_image):
        size = gray_image.shape
        color_image = np.zeros((3, size[0], size[1]))

        for label in range(0, len(self.cmap)):
            mask = (label == gray_image)
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]

        return color_image
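labelcolormap generates a deterministic palette by scattering the bits of each label id across the high bits of R, G and B, so consecutive ids get visually distinct colors. A worked check for label 0 (a sketch):

# Label 0 has id = 1 -> '00000001'. In round j=0 the last three bits
# ('1', '0', '0') land in bit 7 of R, G, B; then id >> 3 is 0 and all
# later rounds add nothing, so the color is (128, 0, 0).
cmap = labelcolormap(182)
print(cmap[0])  # [128   0   0]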
class _ConvBnReLU(nn.Sequential):
    """
    Cascade of 2D convolution, batch norm, and ReLU.
    """

    BATCH_NORM = _BATCH_NORM

    def __init__(
        self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True
    ):
        super(_ConvBnReLU, self).__init__()
        self.add_module(
            "conv",
            nn.Conv2d(
                in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False
            ),
        )
        self.add_module("bn", _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999))

        if relu:
            self.add_module("relu", nn.ReLU())

class _Bottleneck(nn.Module):
    """
    Bottleneck block of MSRA ResNet.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super(_Bottleneck, self).__init__()
        mid_ch = out_ch // _BOTTLENECK_EXPANSION
        self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True)
        self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
        self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False)
        self.shortcut = (
            _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False)
            if downsample
            else nn.Identity()
        )

    def forward(self, x):
        h = self.reduce(x)
        h = self.conv3x3(h)
        h = self.increase(h)
        h += self.shortcut(x)
        return F.relu(h)

class _ResLayer(nn.Sequential):
    """
    Residual layer with multi grids
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
        super(_ResLayer, self).__init__()

        if multi_grids is None:
            multi_grids = [1 for _ in range(n_layers)]
        else:
            assert n_layers == len(multi_grids)

        # Downsampling is only in the first block
        for i in range(n_layers):
            self.add_module(
                "block{}".format(i + 1),
                _Bottleneck(
                    in_ch=(in_ch if i == 0 else out_ch),
                    out_ch=out_ch,
                    stride=(stride if i == 0 else 1),
                    dilation=dilation * multi_grids[i],
                    downsample=(True if i == 0 else False),
                ),
            )

class _Stem(nn.Sequential):
    """
    The 1st conv layer.
    Note that the max pooling is different from both MSRA and FAIR ResNet.
    """

    def __init__(self, out_ch):
        super(_Stem, self).__init__()
        self.add_module("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1))
        self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True))

class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling (ASPP)
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        for i, rate in enumerate(rates):
            self.add_module(
                "c{}".format(i),
                nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True),
            )

        for m in self.children():
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return sum([stage(x) for stage in self.children()])

class MSC(nn.Module):
    """
    Multi-scale inputs
    """

    def __init__(self, base, scales=None):
        super(MSC, self).__init__()
        self.base = base
        if scales:
            self.scales = scales
        else:
            self.scales = [0.5, 0.75]

    def forward(self, x):
        # Original
        logits = self.base(x)
        _, _, H, W = logits.shape
        interp = lambda l: F.interpolate(
            l, size=(H, W), mode="bilinear", align_corners=False
        )

        # Scaled
        logits_pyramid = []
        for p in self.scales:
            h = F.interpolate(x, scale_factor=p, mode="bilinear", align_corners=False)
            logits_pyramid.append(self.base(h))

        # Pixel-wise max
        logits_all = [logits] + [interp(l) for l in logits_pyramid]
        logits_max = torch.max(torch.stack(logits_all), dim=0)[0]

        return logits_max

class DeepLabV2(nn.Sequential):
    """
    DeepLab v2: Dilated ResNet + ASPP
    Output stride is fixed at 8
    """

    def __init__(self, n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]):
        super(DeepLabV2, self).__init__()
        ch = [64 * 2 ** p for p in range(6)]
        self.add_module("layer1", _Stem(ch[0]))
        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
        self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates))

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, _ConvBnReLU.BATCH_NORM):
                m.eval()

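With the stem downsampling by 4 and layer3 by a further 2 (the later layers trade stride for dilation), the advertised output stride of 8 means a 512x512 input yields 64x64 logits, the same spatial grid as the 4x64x64 latent that process_seg samples. A quick shape check with random weights (sketch):

# Output stride 8: 512 / 8 = 64.
model = DeepLabV2(n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24])
model.eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 182, 64, 64])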
def preprocessing(image, device):
    # Resize
    scale = 640 / max(image.shape[:2])
    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
    raw_image = image.astype(np.uint8)

    # Subtract mean values
    image = image.astype(np.float32)
    image -= np.array(
        [
            float(104.008),
            float(116.669),
            float(122.675),
        ]
    )

    # Convert to torch.Tensor and add "batch" axis
    image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
    image = image.to(device)

    return image, raw_image

# Model setup
def seger():
    model = MSC(
        base=DeepLabV2(
            n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
        ),
        scales=[0.5, 0.75],
    )
    state_dict = torch.load('models/deeplabv2_resnet101_msc-cocostuff164k-100000.pth')
    model.load_state_dict(state_dict)  # to skip ASPP

    return model

if __name__ == '__main__':
    device = 'cuda'
    model = seger()
    model.to(device)
    model.eval()
    with torch.no_grad():
        im = cv2.imread('/group/30042/chongmou/ft_local/Diffusion/baselines/SPADE/datasets/coco_stuff/val_img/000000000785.jpg', cv2.IMREAD_COLOR)
        im, raw_im = preprocessing(im, 'cuda')
        _, _, H, W = im.shape

        # Image -> Probability map
        logits = model(im)
        logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        probs = F.softmax(logits, dim=1)[0]
        probs = probs.cpu().data.numpy()
        labelmap = np.argmax(probs, axis=0)
        print(labelmap.shape, np.max(labelmap), np.min(labelmap))
        cv2.imwrite('mask.png', labelmap)
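One caveat in the self-test: mask.png stores raw class indices (0-181), which renders as an almost-black image. To inspect the prediction visually, the same Colorize class the demo uses can be applied first (a sketch):

# Save a color-coded mask instead of raw class indices.
color = Colorize(n=182)(labelmap)                # (3, H, W) float array
color = np.transpose(color, (1, 2, 0)).astype(np.uint8)
cv2.imwrite('mask_color.png', color)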